diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 0ff4e342b..027bd75d4 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -248,7 +248,7 @@ static GaussianBayesNetTree addGaussian( } /* ************************************************************************ */ -GaussianBayesNetTree HybridBayesNet::assembleTree() const { +GaussianBayesNetValTree HybridBayesNet::assembleTree() const { GaussianBayesNetTree result; for (auto &f : factors_) { @@ -276,23 +276,17 @@ GaussianBayesNetTree HybridBayesNet::assembleTree() const { } } - return result; + GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet &gbn) { + return std::make_pair(gbn, 0.0); + }); + return resultTree; } /* ************************************************************************* */ HybridValues HybridBayesNet::optimize() const { // Collect all the discrete factors to compute MPE DiscreteFactorGraph discrete_fg; - VectorValues continuousValues; - std::set discreteKeySet; - - // this->print(); - GaussianBayesNetTree bnTree = assembleTree(); - // bnTree.print("", DefaultKeyFormatter, [](const GaussianBayesNet &gbn) { - // gbn.print(); - // return ""; - // }); /* Perform the integration of L(X;M,Z)P(X|M) which is the model selection term. @@ -316,43 +310,35 @@ HybridValues HybridBayesNet::optimize() const { So we compute (error + log(k)) and exponentiate later */ - // Compute the X* of each assignment and use that as the MAP. - DecisionTree x_map( - bnTree, [](const GaussianBayesNet &gbn) { return gbn.optimize(); }); - // Only compute logNormalizationConstant for now - AlgebraicDecisionTree log_norm_constants = - DecisionTree(bnTree, [](const GaussianBayesNet &gbn) { + std::set discreteKeySet; + GaussianBayesNetValTree bnTree = assembleTree(); + + GaussianBayesNetValTree bn_error = bnTree.apply( + [this](const Assignment &assignment, + const std::pair &gbnAndValue) { + // Compute the X* of each assignment + VectorValues mu = gbnAndValue.first.optimize(); + // Compute the error for X* and the assignment + double error = + this->error(HybridValues(mu, DiscreteValues(assignment))); + + return std::make_pair(gbnAndValue.first, error); + }); + + auto trees = unzip(bn_error); + AlgebraicDecisionTree errorTree = trees.second; + + // Only compute logNormalizationConstant + AlgebraicDecisionTree log_norm_constants = DecisionTree( + bnTree, [](const std::pair &gbnAndValue) { + GaussianBayesNet gbn = gbnAndValue.first; if (gbn.size() == 0) { return 0.0; } return gbn.logNormalizationConstant(); }); - // Compute errors as VectorValues - DecisionTree errorVectors = x_map.apply( - [this](const Assignment &assignment, const VectorValues &mu) { - double error = 0.0; - for (auto &&f : *this) { - if (auto gm = dynamic_pointer_cast(f)) { - error += gm->error(HybridValues(mu, DiscreteValues(assignment))); - - } else if (auto hc = dynamic_pointer_cast(f)) { - if (auto gm = hc->asMixture()) { - error += gm->error(HybridValues(mu, DiscreteValues(assignment))); - - } else if (auto g = hc->asGaussian()) { - error += g->error(mu); - } - } - } - VectorValues e; - e.insert(0, Vector1(error)); - return e; - }); - AlgebraicDecisionTree errorTree = DecisionTree( - errorVectors, [](const VectorValues &v) { return v[0](0); }); - // Compute model selection term (with help from ADT methods) AlgebraicDecisionTree model_selection_term = (errorTree + log_norm_constants) * -1; diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 22e03bba9..f0cdbaaf9 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -118,6 +118,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { return evaluate(values); } + GaussianBayesNetValTree assembleTree() const; + /** * @brief Solve the HybridBayesNet by first computing the MPE of all the * discrete variables and then optimizing the continuous variables based on diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index afd1c8032..8828a9172 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -33,6 +33,14 @@ class HybridValues; /// Alias for DecisionTree of GaussianFactorGraphs using GaussianFactorGraphTree = DecisionTree; +/// Alias for DecisionTree of GaussianBayesNets +using GaussianBayesNetTree = DecisionTree; +/** + * Alias for DecisionTree of (GaussianBayesNet, double) pairs. + * Used for model selection in BayesNet::optimize + */ +using GaussianBayesNetValTree = + DecisionTree>; KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys);