From 5da56c139310a9701ea330ed95fce1e7de1fb135 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 21 Dec 2022 20:19:51 +0530 Subject: [PATCH] add choose method to HybridBayesTree --- gtsam/hybrid/HybridBayesTree.cpp | 14 +++++- gtsam/hybrid/HybridBayesTree.h | 10 +++++ gtsam/hybrid/tests/testHybridBayesTree.cpp | 51 ++++++++++++++++++++++ 3 files changed, 74 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index b706fb745..4ab344a1d 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -138,7 +138,8 @@ struct HybridAssignmentData { /* ************************************************************************* */ -VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { +GaussianBayesTree HybridBayesTree::choose( + const DiscreteValues& assignment) const { GaussianBayesTree gbt; HybridAssignmentData rootData(assignment, 0, &gbt); { @@ -151,6 +152,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { } if (!rootData.isValid()) { + return GaussianBayesTree(); + } + return gbt; +} + +/* ************************************************************************* + */ +VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { + GaussianBayesTree gbt = this->choose(assignment); + // If empty GaussianBayesTree, means a clique is pruned hence invalid + if (gbt.size() == 0) { return VectorValues(); } VectorValues result = gbt.optimize(); diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 2d01aab76..628a453a6 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -24,6 +24,7 @@ #include #include #include +#include #include @@ -76,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { /** Check equality */ bool equals(const This& other, double tol = 1e-9) const; + /** + * @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete + * value assignment. + * + * @param assignment The discrete value assignment for the discrete keys. + * @return GaussianBayesTree + */ + GaussianBayesTree choose(const DiscreteValues& assignment) const; + /** * @brief Optimize the hybrid Bayes tree by computing the MPE for the current * set of discrete variables and using it to compute the best continuous diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index 3992aa023..b4d049210 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -169,6 +169,57 @@ TEST(HybridBayesTree, Optimize) { EXPECT(assert_equal(expectedValues, delta.continuous())); } +/* ****************************************************************************/ +// Test for choosing a GaussianBayesTree from a HybridBayesTree. +TEST(HybridBayesTree, Choose) { + Switching s(4); + + HybridGaussianISAM isam; + HybridGaussianFactorGraph graph1; + + // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4 + for (size_t i = 1; i < 4; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the Gaussian factors, 1 prior on X(0), + // 3 measurements on X(2), X(3), X(4) + graph1.push_back(s.linearizedFactorGraph.at(0)); + for (size_t i = 4; i <= 6; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the discrete factors + for (size_t i = 7; i <= 9; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + isam.update(graph1); + + DiscreteValues assignment; + assignment[M(0)] = 1; + assignment[M(1)] = 1; + assignment[M(2)] = 1; + + GaussianBayesTree gbt = isam.choose(assignment); + + Ordering ordering; + ordering += X(0); + ordering += X(1); + ordering += X(2); + ordering += X(3); + ordering += M(0); + ordering += M(1); + ordering += M(2); + + //TODO(Varun) get segfault if ordering not provided + auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering); + + auto expected_gbt = bayesTree->choose(assignment); + + EXPECT(assert_equal(expected_gbt, gbt)); +} + /* ****************************************************************************/ // Test HybridBayesTree serialization. TEST(HybridBayesTree, Serialization) {