model selection for HybridBayesTree
parent
e7cb7b2dcd
commit
4b2a22eaa5
|
@ -38,19 +38,116 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
|
|||
return Base::equals(other, tol);
|
||||
}
|
||||
|
||||
GaussianBayesNetTree& HybridBayesTree::addCliqueToTree(
|
||||
const sharedClique& clique, GaussianBayesNetTree& result) const {
|
||||
// Perform bottom-up inclusion
|
||||
for (sharedClique child : clique->children) {
|
||||
result = addCliqueToTree(child, result);
|
||||
}
|
||||
|
||||
auto f = clique->conditional();
|
||||
|
||||
if (auto hc = std::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
if (auto gm = hc->asMixture()) {
|
||||
result = gm->add(result);
|
||||
} else if (auto g = hc->asGaussian()) {
|
||||
result = addGaussian(result, g);
|
||||
} else {
|
||||
// Has to be discrete, which we don't add.
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
GaussianBayesNetValTree HybridBayesTree::assembleTree() const {
|
||||
GaussianBayesNetTree result;
|
||||
for (auto&& root : roots_) {
|
||||
result = addCliqueToTree(root, result);
|
||||
}
|
||||
|
||||
GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet& gbn) {
|
||||
return std::make_pair(gbn, 0.0);
|
||||
});
|
||||
return resultTree;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesTree::modelSelection() const {
|
||||
/*
|
||||
To perform model selection, we need:
|
||||
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
|
||||
|
||||
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
|
||||
= exp(log(q) - log(k)) = exp(-error - log(k))
|
||||
= exp(-(error + log(k))),
|
||||
where error is computed at the corresponding MAP point, gbt.error(mu).
|
||||
|
||||
So we compute (error + log(k)) and exponentiate later
|
||||
*/
|
||||
|
||||
GaussianBayesNetValTree bnTree = assembleTree();
|
||||
|
||||
GaussianBayesNetValTree bn_error = bnTree.apply(
|
||||
[this](const Assignment<Key>& assignment,
|
||||
const std::pair<GaussianBayesNet, double>& gbnAndValue) {
|
||||
// Compute the X* of each assignment
|
||||
VectorValues mu = gbnAndValue.first.optimize();
|
||||
|
||||
// mu is empty if gbn had nullptrs
|
||||
if (mu.size() == 0) {
|
||||
return std::make_pair(gbnAndValue.first,
|
||||
std::numeric_limits<double>::max());
|
||||
}
|
||||
|
||||
// Compute the error for X* and the assignment
|
||||
double error =
|
||||
this->error(HybridValues(mu, DiscreteValues(assignment)));
|
||||
|
||||
return std::make_pair(gbnAndValue.first, error);
|
||||
});
|
||||
|
||||
auto trees = unzip(bn_error);
|
||||
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
||||
|
||||
// Only compute logNormalizationConstant
|
||||
AlgebraicDecisionTree<Key> log_norm_constants =
|
||||
computeLogNormConstants(bnTree);
|
||||
|
||||
// Compute model selection term (with help from ADT methods)
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
||||
computeModelSelectionTerm(errorTree, log_norm_constants);
|
||||
|
||||
return modelSelectionTerm;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesTree::optimize() const {
|
||||
DiscreteBayesNet dbn;
|
||||
DiscreteFactorGraph discrete_fg;
|
||||
DiscreteValues mpe;
|
||||
|
||||
// Compute model selection term
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm = modelSelection();
|
||||
|
||||
auto root = roots_.at(0);
|
||||
// Access the clique and get the underlying hybrid conditional
|
||||
HybridConditional::shared_ptr root_conditional = root->conditional();
|
||||
|
||||
// Get the set of all discrete keys involved in model selection
|
||||
std::set<DiscreteKey> discreteKeySet;
|
||||
|
||||
// The root should be discrete only, we compute the MPE
|
||||
if (root_conditional->isDiscrete()) {
|
||||
dbn.push_back(root_conditional->asDiscrete());
|
||||
mpe = DiscreteFactorGraph(dbn).optimize();
|
||||
discrete_fg.push_back(root_conditional->asDiscrete());
|
||||
|
||||
// Only add model_selection if we have discrete keys
|
||||
if (discreteKeySet.size() > 0) {
|
||||
discrete_fg.push_back(DecisionTreeFactor(
|
||||
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
|
||||
modelSelectionTerm));
|
||||
}
|
||||
mpe = discrete_fg.optimize();
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"HybridBayesTree root is not discrete-only. Please check elimination "
|
||||
|
|
|
@ -84,6 +84,51 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
|||
*/
|
||||
GaussianBayesTree choose(const DiscreteValues& assignment) const;
|
||||
|
||||
/** Error for all conditionals. */
|
||||
double error(const HybridValues& values) const {
|
||||
return HybridGaussianFactorGraph(*this).error(values);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to add a clique of hybrid conditionals to the passed
|
||||
* in GaussianBayesNetTree. Operates recursively on the clique in a bottom-up
|
||||
* fashion, adding the children first.
|
||||
*
|
||||
* @param clique The
|
||||
* @param result
|
||||
* @return GaussianBayesNetTree&
|
||||
*/
|
||||
GaussianBayesNetTree& addCliqueToTree(const sharedClique& clique,
|
||||
GaussianBayesNetTree& result) const;
|
||||
|
||||
/**
|
||||
* @brief Assemble a DecisionTree of (GaussianBayesTree, double) leaves for
|
||||
* each discrete assignment.
|
||||
* The included double value is used to make
|
||||
* constructing the model selection term cleaner and more efficient.
|
||||
*
|
||||
* @return GaussianBayesNetValTree
|
||||
*/
|
||||
GaussianBayesNetValTree assembleTree() const;
|
||||
|
||||
/*
|
||||
Compute L(M;Z), the likelihood of the discrete model M
|
||||
given the measurements Z.
|
||||
This is called the model selection term.
|
||||
|
||||
To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M).
|
||||
|
||||
By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
|
||||
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
|
||||
the joint Gaussian distribution.
|
||||
|
||||
This can be computed by multiplying all the exponentiated errors
|
||||
of each of the conditionals.
|
||||
|
||||
Return a tree where each leaf value is L(M_i;Z).
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> modelSelection() const;
|
||||
|
||||
/**
|
||||
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
|
||||
* set of discrete variables and using it to compute the best continuous
|
||||
|
|
Loading…
Reference in New Issue