diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 30fe6f168..e9c15b0ed 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -18,6 +18,8 @@ */ #include +#include +#include #include #include #include @@ -35,6 +37,42 @@ bool HybridBayesTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } +/* ************************************************************************* */ +HybridValues HybridBayesTree::optimize() const { + HybridBayesNet hbn; + DiscreteBayesNet dbn; + + KeyVector added_keys; + + // Iterate over all the nodes in the BayesTree + for (auto&& node : nodes()) { + // Check if conditional being added is already in the Bayes net. + if (std::find(added_keys.begin(), added_keys.end(), node.first) == + added_keys.end()) { + // Access the clique and get the underlying hybrid conditional + HybridBayesTreeClique::shared_ptr clique = node.second; + HybridConditional::shared_ptr conditional = clique->conditional(); + + // Record the key being added + added_keys.insert(added_keys.end(), conditional->frontals().begin(), + conditional->frontals().end()); + + if (conditional->isHybrid()) { + // If conditional is hybrid, add it to a Hybrid Bayes net. + hbn.push_back(conditional); + } else if (conditional->isDiscrete()) { + // Else if discrete, we use it to compute the MPE + dbn.push_back(conditional->asDiscreteConditional()); + } + } + } + // Get the MPE + DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize(); + // Given the MPE, compute the optimal continuous values. + GaussianBayesNet gbn = hbn.choose(mpe); + return HybridValues(mpe, gbn.optimize()); +} + /* ************************************************************************* */ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { GaussianBayesNet gbn; @@ -50,11 +88,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { HybridBayesTreeClique::shared_ptr clique = node.second; HybridConditional::shared_ptr conditional = clique->conditional(); - KeyVector frontals(conditional->frontals().begin(), - conditional->frontals().end()); - // Record the key being added - added_keys.insert(added_keys.end(), frontals.begin(), frontals.end()); + added_keys.insert(added_keys.end(), conditional->frontals().begin(), + conditional->frontals().end()); // If conditional is hybrid (and not discrete-only), we get the Gaussian // Conditional corresponding to the assignment and add it to the Gaussian diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index b7950a483..361fbe86f 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { /** Check equality */ bool equals(const This& other, double tol = 1e-9) const; + /** + * @brief Optimize the hybrid Bayes tree by computing the MPE for the current + * set of discrete variables and using it to compute the best continuous + * update delta. + * + * @return HybridValues + */ + HybridValues optimize() const; + /** * @brief Recursively optimize the BayesTree to produce a vector solution. * diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index bcaf7e599..56f3680ae 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -125,11 +125,6 @@ TEST(HybridBayesNet, OptimizeAssignment) { TEST(HybridBayesNet, Optimize) { Switching s(4); - Ordering ordering; - for (auto&& kvp : s.linearizationPoint) { - ordering += kvp.key; - } - Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); HybridBayesNet::shared_ptr hybridBayesNet = s.linearizedFactorGraph.eliminateSequential(hybridOrdering); diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index a693af1b2..3584d7db9 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -16,6 +16,7 @@ * @date August 2022 */ +#include #include #include @@ -31,8 +32,8 @@ using symbol_shorthand::M; using symbol_shorthand::X; /* ****************************************************************************/ -// Test for optimizing a HybridBayesTree. -TEST(HybridBayesTree, Optimize) { +// Test for optimizing a HybridBayesTree with a given assignment. +TEST(HybridBayesTree, OptimizeAssignment) { Switching s(4); HybridGaussianISAM isam; @@ -85,6 +86,58 @@ TEST(HybridBayesTree, Optimize) { EXPECT(assert_equal(expected, delta)); } +/* ****************************************************************************/ +// Test for optimizing a HybridBayesTree. +TEST(HybridBayesTree, Optimize) { + Switching s(4); + + HybridGaussianISAM isam; + HybridGaussianFactorGraph graph1; + + // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4 + for (size_t i = 1; i < 4; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the Gaussian factors, 1 prior on X(1), + // 3 measurements on X(2), X(3), X(4) + graph1.push_back(s.linearizedFactorGraph.at(0)); + for (size_t i = 4; i <= 6; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the discrete factors + for (size_t i = 7; i <= 9; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + isam.update(graph1); + + HybridValues delta = isam.optimize(); + + // Create ordering. + Ordering ordering; + for (size_t k = 1; k <= s.K; k++) ordering += X(k); + + HybridBayesNet::shared_ptr hybridBayesNet; + HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; + std::tie(hybridBayesNet, remainingFactorGraph) = + s.linearizedFactorGraph.eliminatePartialSequential(ordering); + + DiscreteFactorGraph dfg; + for (auto&& f : *remainingFactorGraph) { + auto factor = dynamic_pointer_cast(f); + dfg.push_back( + boost::dynamic_pointer_cast(factor->inner())); + } + + DiscreteValues expectedMPE = dfg.optimize(); + VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE); + + EXPECT(assert_equal(expectedMPE, delta.discrete())); + EXPECT(assert_equal(expectedValues, delta.continuous())); +} + /* ************************************************************************* */ int main() { TestResult tr;