remove model selection from hybrid bayes tree
parent 37c6484cbd
commit 0b1c3688c4
@@ -378,37 +378,4 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
   }
   return fg;
 }
 
-/* ************************************************************************ */
-GaussianBayesNetTree addGaussian(
-    const GaussianBayesNetTree &gbnTree,
-    const GaussianConditional::shared_ptr &factor) {
-  // If the decision tree is not initialized, then initialize it.
-  if (gbnTree.empty()) {
-    GaussianBayesNet result{factor};
-    return GaussianBayesNetTree(result);
-  } else {
-    auto add = [&factor](const GaussianBayesNet &graph) {
-      auto result = graph;
-      result.push_back(factor);
-      return result;
-    };
-    return gbnTree.apply(add);
-  }
-}
-
-/* ************************************************************************* */
-AlgebraicDecisionTree<Key> computeLogNormConstants(
-    const GaussianBayesNetValTree &bnTree) {
-  AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
-      bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
-        GaussianBayesNet gbn = gbnAndValue.first;
-        if (gbn.size() == 0) {
-          return 0.0;
-        }
-        return gbn.logNormalizationConstant();
-      });
-  return log_norm_constants;
-}
-
 }  // namespace gtsam
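Note (context, not part of the diff): the removed computeLogNormConstants fills each leaf of the decision tree with the log normalization constant of that leaf's Gaussian Bayes net. In the notation of the modelSelection() comment further below, the constant of an n-dimensional Gaussian with covariance Sigma is

    k = 1 / sqrt((2*pi)^n * det(Sigma)),
    log(k) = -0.5 * (n*log(2*pi) + log(det(Sigma))),

and the constant of a whole Bayes net is the sum of its conditionals' log constants, since the net's density is the product of its conditionals.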
@@ -262,15 +262,4 @@ struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
 GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
                                  const GaussianConditional::shared_ptr &factor);
 
-/**
- * @brief Compute the (logarithmic) normalization constant for each Bayes
- * network in the tree.
- *
- * @param bnTree A tree of Bayes networks in each leaf. The tree encodes a
- * discrete assignment yielding the Bayes net.
- * @return AlgebraicDecisionTree<Key>
- */
-AlgebraicDecisionTree<Key> computeLogNormConstants(
-    const GaussianBayesNetValTree &bnTree);
-
 }  // namespace gtsam
@@ -38,117 +38,18 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
   return Base::equals(other, tol);
 }
 
-GaussianBayesNetTree& HybridBayesTree::addCliqueToTree(
-    const sharedClique& clique, GaussianBayesNetTree& result) const {
-  // Perform bottom-up inclusion
-  for (sharedClique child : clique->children) {
-    result = addCliqueToTree(child, result);
-  }
-
-  auto f = clique->conditional();
-
-  if (auto hc = std::dynamic_pointer_cast<HybridConditional>(f)) {
-    if (auto gm = hc->asMixture()) {
-      result = gm->add(result);
-    } else if (auto g = hc->asGaussian()) {
-      result = addGaussian(result, g);
-    } else {
-      // Has to be discrete, which we don't add.
-    }
-  }
-  return result;
-}
-
-/* ************************************************************************ */
-GaussianBayesNetValTree HybridBayesTree::assembleTree() const {
-  GaussianBayesNetTree result;
-  for (auto&& root : roots_) {
-    result = addCliqueToTree(root, result);
-  }
-
-  GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet& gbn) {
-    return std::make_pair(gbn, 0.0);
-  });
-  return resultTree;
-}
-
-/* ************************************************************************* */
-AlgebraicDecisionTree<Key> HybridBayesTree::modelSelection() const {
-  /*
-    To perform model selection, we need:
-    q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
-
-    If q(mu; M, Z) = exp(-error) and k = 1.0 / sqrt((2*pi)^n*det(Sigma)),
-    then q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
-    = exp(log(q) - log(k)) = exp(-error - log(k))
-    = exp(-(error + log(k))),
-    where error is computed at the corresponding MAP point, gbt.error(mu).
-
-    So we compute (error + log(k)) and exponentiate later.
-  */
-
-  GaussianBayesNetValTree bnTree = assembleTree();
-
-  GaussianBayesNetValTree bn_error = bnTree.apply(
-      [this](const Assignment<Key>& assignment,
-             const std::pair<GaussianBayesNet, double>& gbnAndValue) {
-        // Compute the X* of each assignment
-        VectorValues mu = gbnAndValue.first.optimize();
-
-        // mu is empty if gbn had nullptrs
-        if (mu.size() == 0) {
-          return std::make_pair(gbnAndValue.first,
-                                std::numeric_limits<double>::max());
-        }
-
-        // Compute the error for X* and the assignment
-        double error =
-            this->error(HybridValues(mu, DiscreteValues(assignment)));
-
-        return std::make_pair(gbnAndValue.first, error);
-      });
-
-  auto trees = unzip(bn_error);
-  AlgebraicDecisionTree<Key> errorTree = trees.second;
-
-  // Compute model selection term (with help from ADT methods)
-  AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
-
-  // Exponentiate using our scheme
-  double max_log = modelSelectionTerm.max();
-  modelSelectionTerm = DecisionTree<Key, double>(
-      modelSelectionTerm,
-      [&max_log](const double& x) { return std::exp(x - max_log); });
-  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
-
-  return modelSelectionTerm;
-}
-
 /* ************************************************************************* */
 HybridValues HybridBayesTree::optimize() const {
   DiscreteFactorGraph discrete_fg;
   DiscreteValues mpe;
 
-  // Compute model selection term
-  AlgebraicDecisionTree<Key> modelSelectionTerm = modelSelection();
-
   auto root = roots_.at(0);
   // Access the clique and get the underlying hybrid conditional
   HybridConditional::shared_ptr root_conditional = root->conditional();
 
-  // Get the set of all discrete keys involved in model selection
-  std::set<DiscreteKey> discreteKeySet;
-
   // The root should be discrete only, we compute the MPE
   if (root_conditional->isDiscrete()) {
     discrete_fg.push_back(root_conditional->asDiscrete());
-
-    // Only add model_selection if we have discrete keys
-    if (discreteKeySet.size() > 0) {
-      discrete_fg.push_back(DecisionTreeFactor(
-          DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
-          modelSelectionTerm));
-    }
     mpe = discrete_fg.optimize();
   } else {
     throw std::runtime_error(
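Note on the exponentiation scheme in the removed modelSelection(): subtracting the maximum log value before calling exp keeps the values in a safe numeric range, and the shift cancels in the final normalization. A minimal standalone sketch of the same trick, using a plain std::map in place of AlgebraicDecisionTree (illustrative only, not GTSAM API):

#include <algorithm>
#include <cmath>
#include <limits>
#include <map>
#include <string>

// Convert per-assignment log-weights into normalized probabilities.
// The common factor exp(-max_log) introduced by the shift cancels
// when dividing by the sum.
std::map<std::string, double> normalizeLogWeights(
    std::map<std::string, double> logWeights) {
  double max_log = -std::numeric_limits<double>::infinity();
  for (const auto& kv : logWeights) max_log = std::max(max_log, kv.second);

  double sum = 0.0;
  for (auto& kv : logWeights) {
    kv.second = std::exp(kv.second - max_log);
    sum += kv.second;
  }
  for (auto& kv : logWeights) kv.second /= sum;
  return logWeights;
}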
@@ -89,46 +89,6 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
     return HybridGaussianFactorGraph(*this).error(values);
   }
 
-  /**
-   * @brief Helper function to add a clique of hybrid conditionals to the
-   * passed-in GaussianBayesNetTree. Operates recursively on the clique in a
-   * bottom-up fashion, adding the children first.
-   *
-   * @param clique The clique to add.
-   * @param result The GaussianBayesNetTree accumulated so far.
-   * @return GaussianBayesNetTree&
-   */
-  GaussianBayesNetTree& addCliqueToTree(const sharedClique& clique,
-                                        GaussianBayesNetTree& result) const;
-
-  /**
-   * @brief Assemble a DecisionTree of (GaussianBayesNet, double) leaves for
-   * each discrete assignment. The included double value is used to make
-   * constructing the model selection term cleaner and more efficient.
-   *
-   * @return GaussianBayesNetValTree
-   */
-  GaussianBayesNetValTree assembleTree() const;
-
-  /*
-    Compute L(M;Z), the likelihood of the discrete model M
-    given the measurements Z. This is called the model selection term.
-
-    To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M).
-
-    By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
-    hence L(X;M,Z)P(X|M) is the unnormalized probability of
-    the joint Gaussian distribution.
-
-    This can be computed by multiplying all the exponentiated errors
-    of each of the conditionals.
-
-    Return a tree where each leaf value is L(M_i;Z).
-  */
-  AlgebraicDecisionTree<Key> modelSelection() const;
-
   /**
    * @brief Optimize the hybrid Bayes tree by computing the MPE for the current
    * set of discrete variables and using it to compute the best continuous
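For reference, the integral in the removed modelSelection() doc comment works out as follows (a sketch in the comment's own notation): the unnormalized joint q(x; M, Z) = L(X;M,Z)P(X|M) is Gaussian-shaped with mode mu and covariance Sigma, so

    L(M;Z) ∝ ∫ q(x; M, Z) dx
           = q(mu; M, Z) * sqrt((2*pi)^n * det(Sigma))
           = exp(-error(mu)) / k
           = exp(-(error(mu) + log(k))),

with k = 1 / sqrt((2*pi)^n * det(Sigma)). This is exactly the per-assignment quantity the removed implementation was meant to evaluate before exponentiation and normalization.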
@@ -252,30 +252,16 @@ TEST(MixtureFactor, DifferentCovariances) {
 
   // Check that we get different error values at the MLE point μ.
   AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
-  auto cond0 = hbn->at(0)->asMixture();
-  auto cond1 = hbn->at(1)->asMixture();
-  auto discrete_cond = hbn->at(2)->asDiscrete();
-
   HybridValues hv0(cv, DiscreteValues{{M(1), 0}});
   HybridValues hv1(cv, DiscreteValues{{M(1), 1}});
-  AlgebraicDecisionTree<Key> expectedErrorTree(
-      m1,
-      cond0->error(hv0)  // cond0(0)->logNormalizationConstant()
-                         // - cond0(1)->logNormalizationConstant
-          + cond1->error(hv0) + discrete_cond->error(DiscreteValues{{M(1), 0}}),
-      cond0->error(hv1)  // cond1(0)->logNormalizationConstant()
-                         // - cond1(1)->logNormalizationConstant
-          + cond1->error(hv1) +
-          discrete_cond->error(DiscreteValues{{M(1), 0}}));
+  auto cond0 = hbn->at(0)->asMixture();
+  auto cond1 = hbn->at(1)->asMixture();
+  auto discrete_cond = hbn->at(2)->asDiscrete();
+  AlgebraicDecisionTree<Key> expectedErrorTree(m1, 9.90348755254,
+                                               0.69314718056);
   EXPECT(assert_equal(expectedErrorTree, errorTree));
-
-  DiscreteValues dv;
-  dv.insert({M(1), 1});
-  HybridValues expected_values(cv, dv);
-
-  HybridValues actual_values = hbn->optimize();
-
-  EXPECT(assert_equal(expected_values, actual_values));
 }
 
 /* ************************************************************************* */
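One observation on the new hardcoded regression values (not stated in the diff itself): the second leaf matches

    0.69314718056 ≈ ln(2) = -ln(0.5),

which is consistent with a uniform binary discrete prior contributing -log P(m1 = 1) while the continuous conditionals contribute no error at the chosen point for that mode; 9.90348755254 is then the corresponding total error for m1 = 0.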
@@ -329,16 +315,7 @@ TEST(MixtureFactor, DifferentMeansAndCovariances) {
   auto prior = PriorFactor<double>(X(1), x1, prior_noise).linearize(values);
   mixture_fg.push_back(prior);
-
-  // bn.print("BayesNet:");
-  // mixture_fg.print("\n\n");
-
   VectorValues vv{{X(1), x1 * I_1x1}, {X(2), x2 * I_1x1}};
-  // std::cout << "FG error for m1=0: "
-  //           << mixture_fg.error(HybridValues(vv, DiscreteValues{{m1.first, 0}}))
-  //           << std::endl;
-  // std::cout << "FG error for m1=1: "
-  //           << mixture_fg.error(HybridValues(vv, DiscreteValues{{m1.first, 1}}))
-  //           << std::endl;
-
   auto hbn = mixture_fg.eliminateSequential();
@@ -347,8 +324,10 @@ TEST(MixtureFactor, DifferentMeansAndCovariances) {
   VectorValues cv;
   cv.insert(X(1), Vector1(0.0));
   cv.insert(X(2), Vector1(-7.0));
 
+  // The first value is chosen as the tiebreaker
   DiscreteValues dv;
-  dv.insert({M(1), 1});
+  dv.insert({M(1), 0});
   HybridValues expected_values(cv, dv);
 
   EXPECT(assert_equal(expected_values, actual_values));