From cb2d2e678d4d0fdd33437895865b43a56af7fa0c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 26 Aug 2022 11:36:13 -0400 Subject: [PATCH] HybridBayesTree::optimize --- gtsam/hybrid/HybridBayesTree.cpp | 37 +++++++ gtsam/hybrid/HybridBayesTree.h | 9 ++ gtsam/hybrid/tests/testHybridBayesTree.cpp | 106 +++++++++++++++++++++ 3 files changed, 152 insertions(+) create mode 100644 gtsam/hybrid/tests/testHybridBayesTree.cpp diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index d65270f91..30fe6f168 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -35,4 +35,41 @@ bool HybridBayesTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } +/* ************************************************************************* */ +VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { + GaussianBayesNet gbn; + + KeyVector added_keys; + + // Iterate over all the nodes in the BayesTree + for (auto&& node : nodes()) { + // Check if conditional being added is already in the Bayes net. + if (std::find(added_keys.begin(), added_keys.end(), node.first) == + added_keys.end()) { + // Access the clique and get the underlying hybrid conditional + HybridBayesTreeClique::shared_ptr clique = node.second; + HybridConditional::shared_ptr conditional = clique->conditional(); + + KeyVector frontals(conditional->frontals().begin(), + conditional->frontals().end()); + + // Record the key being added + added_keys.insert(added_keys.end(), frontals.begin(), frontals.end()); + + // If conditional is hybrid (and not discrete-only), we get the Gaussian + // Conditional corresponding to the assignment and add it to the Gaussian + // Bayes Net. + if (conditional->isHybrid()) { + auto gm = conditional->asMixture(); + GaussianConditional::shared_ptr gaussian_conditional = + (*gm)(assignment); + + gbn.push_back(gaussian_conditional); + } + } + } + // Return the optimized bayes net. + return gbn.optimize(); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 02a4a11e5..b7950a483 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { /** Check equality */ bool equals(const This& other, double tol = 1e-9) const; + /** + * @brief Recursively optimize the BayesTree to produce a vector solution. + * + * @param assignment The discrete values assignment to select the Gaussian + * mixtures. + * @return VectorValues + */ + VectorValues optimize(const DiscreteValues& assignment) const; + /// @} }; diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp new file mode 100644 index 000000000..ddc704460 --- /dev/null +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -0,0 +1,106 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridBayesTree.cpp + * @brief Unit tests for HybridBayesTree + * @author Varun Agrawal + * @date August 2022 + */ + +#include +#include + +#include "Switching.h" + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::M; +using symbol_shorthand::X; + +/* ****************************************************************************/ +// Test for optimizing a HybridBayesTree. +TEST(HybridBayesTree, Optimize) { + Switching s(4); + + HybridGaussianISAM isam; + HybridGaussianFactorGraph graph1; + + // Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4 + for (size_t i = 1; i < 4; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + // Add the Gaussian factors, 1 prior on X(1), + // 3 measurements on X(2), X(3), X(4) + graph1.push_back(s.linearizedFactorGraph.at(0)); + for (size_t i = 4; i <= 7; i++) { + graph1.push_back(s.linearizedFactorGraph.at(i)); + } + + isam.update(graph1); + + DiscreteValues assignment; + assignment[M(1)] = 1; + assignment[M(2)] = 1; + assignment[M(3)] = 1; + + VectorValues delta = isam.optimize(assignment); + + // The linearization point has the same value as the key index, + // e.g. X(1) = 1, X(2) = 2, + // but the factors specify X(k) = k-1, so delta should be -1. + VectorValues expected_delta; + expected_delta.insert(make_pair(X(1), -Vector1::Ones())); + expected_delta.insert(make_pair(X(2), -Vector1::Ones())); + expected_delta.insert(make_pair(X(3), -Vector1::Ones())); + expected_delta.insert(make_pair(X(4), -Vector1::Ones())); + + EXPECT(assert_equal(expected_delta, delta)); + + // Create ordering. + Ordering ordering; + for (size_t k = 1; k <= s.K; k++) ordering += X(k); + + HybridBayesNet::shared_ptr hybridBayesNet; + HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; + std::tie(hybridBayesNet, remainingFactorGraph) = + s.linearizedFactorGraph.eliminatePartialSequential(ordering); + + hybridBayesNet->print(); + GaussianBayesNet gbn = hybridBayesNet->choose(assignment); + + // EXPECT_LONGS_EQUAL(4, gbn.size()); + + // EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( + // hybridBayesNet->atGaussian(0)))(assignment), + // *gbn.at(0))); + // EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( + // hybridBayesNet->atGaussian(1)))(assignment), + // *gbn.at(1))); + // EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( + // hybridBayesNet->atGaussian(2)))(assignment), + // *gbn.at(2))); + // EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( + // hybridBayesNet->atGaussian(3)))(assignment), + // *gbn.at(3))); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */