diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 2bced6c0d..33c7c91da 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -378,37 +378,4 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph( } return fg; } - -/* ************************************************************************ */ -GaussianBayesNetTree addGaussian( - const GaussianBayesNetTree &gbnTree, - const GaussianConditional::shared_ptr &factor) { - // If the decision tree is not initialized, then initialize it. - if (gbnTree.empty()) { - GaussianBayesNet result{factor}; - return GaussianBayesNetTree(result); - } else { - auto add = [&factor](const GaussianBayesNet &graph) { - auto result = graph; - result.push_back(factor); - return result; - }; - return gbnTree.apply(add); - } -} - -/* ************************************************************************* */ -AlgebraicDecisionTree computeLogNormConstants( - const GaussianBayesNetValTree &bnTree) { - AlgebraicDecisionTree log_norm_constants = DecisionTree( - bnTree, [](const std::pair &gbnAndValue) { - GaussianBayesNet gbn = gbnAndValue.first; - if (gbn.size() == 0) { - return 0.0; - } - return gbn.logNormalizationConstant(); - }); - return log_norm_constants; -} - } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 6f7b9aae7..55eaf6b5e 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -262,15 +262,4 @@ struct traits : public Testable {}; GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree, const GaussianConditional::shared_ptr &factor); -/** - * @brief Compute the (logarithmic) normalization constant for each Bayes - * network in the tree. - * - * @param bnTree A tree of Bayes networks in each leaf. The tree encodes a - * discrete assignment yielding the Bayes net. - * @return AlgebraicDecisionTree - */ -AlgebraicDecisionTree computeLogNormConstants( - const GaussianBayesNetValTree &bnTree); - } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 394c88928..f08eff01b 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -38,117 +38,18 @@ bool HybridBayesTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } -GaussianBayesNetTree& HybridBayesTree::addCliqueToTree( - const sharedClique& clique, GaussianBayesNetTree& result) const { - // Perform bottom-up inclusion - for (sharedClique child : clique->children) { - result = addCliqueToTree(child, result); - } - - auto f = clique->conditional(); - - if (auto hc = std::dynamic_pointer_cast(f)) { - if (auto gm = hc->asMixture()) { - result = gm->add(result); - } else if (auto g = hc->asGaussian()) { - result = addGaussian(result, g); - } else { - // Has to be discrete, which we don't add. - } - } - return result; -} - -/* ************************************************************************ */ -GaussianBayesNetValTree HybridBayesTree::assembleTree() const { - GaussianBayesNetTree result; - for (auto&& root : roots_) { - result = addCliqueToTree(root, result); - } - - GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet& gbn) { - return std::make_pair(gbn, 0.0); - }); - return resultTree; -} - -/* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesTree::modelSelection() const { - /* - To perform model selection, we need: - q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma)) - - If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma)) - thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k)) - = exp(log(q) - log(k)) = exp(-error - log(k)) - = exp(-(error + log(k))), - where error is computed at the corresponding MAP point, gbt.error(mu). - - So we compute (error + log(k)) and exponentiate later - */ - - GaussianBayesNetValTree bnTree = assembleTree(); - - GaussianBayesNetValTree bn_error = bnTree.apply( - [this](const Assignment& assignment, - const std::pair& gbnAndValue) { - // Compute the X* of each assignment - VectorValues mu = gbnAndValue.first.optimize(); - - // mu is empty if gbn had nullptrs - if (mu.size() == 0) { - return std::make_pair(gbnAndValue.first, - std::numeric_limits::max()); - } - - // Compute the error for X* and the assignment - double error = - this->error(HybridValues(mu, DiscreteValues(assignment))); - - return std::make_pair(gbnAndValue.first, error); - }); - - auto trees = unzip(bn_error); - AlgebraicDecisionTree errorTree = trees.second; - - // Compute model selection term (with help from ADT methods) - AlgebraicDecisionTree modelSelectionTerm = errorTree * -1; - - // Exponentiate using our scheme - double max_log = modelSelectionTerm.max(); - modelSelectionTerm = DecisionTree( - modelSelectionTerm, - [&max_log](const double& x) { return std::exp(x - max_log); }); - modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum()); - - return modelSelectionTerm; -} - /* ************************************************************************* */ HybridValues HybridBayesTree::optimize() const { DiscreteFactorGraph discrete_fg; DiscreteValues mpe; - // Compute model selection term - AlgebraicDecisionTree modelSelectionTerm = modelSelection(); - auto root = roots_.at(0); // Access the clique and get the underlying hybrid conditional HybridConditional::shared_ptr root_conditional = root->conditional(); - // Get the set of all discrete keys involved in model selection - std::set discreteKeySet; - // The root should be discrete only, we compute the MPE if (root_conditional->isDiscrete()) { discrete_fg.push_back(root_conditional->asDiscrete()); - - // Only add model_selection if we have discrete keys - if (discreteKeySet.size() > 0) { - discrete_fg.push_back(DecisionTreeFactor( - DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()), - modelSelectionTerm)); - } mpe = discrete_fg.optimize(); } else { throw std::runtime_error( diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 8327b7f31..af8eb3228 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -89,46 +89,6 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { return HybridGaussianFactorGraph(*this).error(values); } - /** - * @brief Helper function to add a clique of hybrid conditionals to the passed - * in GaussianBayesNetTree. Operates recursively on the clique in a bottom-up - * fashion, adding the children first. - * - * @param clique The - * @param result - * @return GaussianBayesNetTree& - */ - GaussianBayesNetTree& addCliqueToTree(const sharedClique& clique, - GaussianBayesNetTree& result) const; - - /** - * @brief Assemble a DecisionTree of (GaussianBayesTree, double) leaves for - * each discrete assignment. - * The included double value is used to make - * constructing the model selection term cleaner and more efficient. - * - * @return GaussianBayesNetValTree - */ - GaussianBayesNetValTree assembleTree() const; - - /* - Compute L(M;Z), the likelihood of the discrete model M - given the measurements Z. - This is called the model selection term. - - To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M). - - By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M), - hence L(X;M,Z)P(X|M) is the unnormalized probabilty of - the joint Gaussian distribution. - - This can be computed by multiplying all the exponentiated errors - of each of the conditionals. - - Return a tree where each leaf value is L(M_i;Z). - */ - AlgebraicDecisionTree modelSelection() const; - /** * @brief Optimize the hybrid Bayes tree by computing the MPE for the current * set of discrete variables and using it to compute the best continuous diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 9de277214..1bd4f5b88 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -252,30 +252,16 @@ TEST(MixtureFactor, DifferentCovariances) { // Check that we get different error values at the MLE point μ. AlgebraicDecisionTree errorTree = hbn->errorTree(cv); - auto cond0 = hbn->at(0)->asMixture(); - auto cond1 = hbn->at(1)->asMixture(); - auto discrete_cond = hbn->at(2)->asDiscrete(); HybridValues hv0(cv, DiscreteValues{{M(1), 0}}); HybridValues hv1(cv, DiscreteValues{{M(1), 1}}); - AlgebraicDecisionTree expectedErrorTree( - m1, - cond0->error(hv0) // cond0(0)->logNormalizationConstant() - // - cond0(1)->logNormalizationConstant - + cond1->error(hv0) + discrete_cond->error(DiscreteValues{{M(1), 0}}), - cond0->error(hv1) // cond1(0)->logNormalizationConstant() - // - cond1(1)->logNormalizationConstant - + cond1->error(hv1) + - discrete_cond->error(DiscreteValues{{M(1), 0}})); + + auto cond0 = hbn->at(0)->asMixture(); + auto cond1 = hbn->at(1)->asMixture(); + auto discrete_cond = hbn->at(2)->asDiscrete(); + AlgebraicDecisionTree expectedErrorTree(m1, 9.90348755254, + 0.69314718056); EXPECT(assert_equal(expectedErrorTree, errorTree)); - - DiscreteValues dv; - dv.insert({M(1), 1}); - HybridValues expected_values(cv, dv); - - HybridValues actual_values = hbn->optimize(); - - EXPECT(assert_equal(expected_values, actual_values)); } /* ************************************************************************* */ @@ -329,16 +315,7 @@ TEST(MixtureFactor, DifferentMeansAndCovariances) { auto prior = PriorFactor(X(1), x1, prior_noise).linearize(values); mixture_fg.push_back(prior); - // bn.print("BayesNet:"); - // mixture_fg.print("\n\n"); - VectorValues vv{{X(1), x1 * I_1x1}, {X(2), x2 * I_1x1}}; - // std::cout << "FG error for m1=0: " - // << mixture_fg.error(HybridValues(vv, DiscreteValues{{m1.first, 0}})) - // << std::endl; - // std::cout << "FG error for m1=1: " - // << mixture_fg.error(HybridValues(vv, DiscreteValues{{m1.first, 1}})) - // << std::endl; auto hbn = mixture_fg.eliminateSequential(); @@ -347,8 +324,10 @@ TEST(MixtureFactor, DifferentMeansAndCovariances) { VectorValues cv; cv.insert(X(1), Vector1(0.0)); cv.insert(X(2), Vector1(-7.0)); + + // The first value is chosen as the tiebreaker DiscreteValues dv; - dv.insert({M(1), 1}); + dv.insert({M(1), 0}); HybridValues expected_values(cv, dv); EXPECT(assert_equal(expected_values, actual_values));