diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 1292711d8..b3df73bf2 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -10,7 +10,34 @@ * @file HybridBayesNet.cpp * @brief A bayes net of Gaussian Conditionals indexed by discrete keys. * @author Fan Jiang + * @author Varun Agrawal * @date January 2022 */ #include + +namespace gtsam { + +/* ************************************************************************* */ +GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const { + return boost::dynamic_pointer_cast(factors_.at(i)->inner()); +} + +/* ************************************************************************* */ +DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { + return boost::dynamic_pointer_cast( + factors_.at(i)->inner()); +} + +/* ************************************************************************* */ +GaussianBayesNet HybridBayesNet::choose( + const DiscreteValues &assignment) const { + GaussianBayesNet gbn; + for (size_t idx = 0; idx < size(); idx++) { + GaussianMixture gm = *this->atGaussian(idx); + gbn.push_back(gm(assignment)); + } + return gbn; +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 43eead280..412b208b9 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -19,6 +19,7 @@ #include #include +#include namespace gtsam { @@ -36,6 +37,27 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /** Construct empty bayes net */ HybridBayesNet() = default; + + /// Add a discrete conditional to the Bayes Net. + void add(const DiscreteKey &key, const std::string &table) { + push_back( + HybridConditional(boost::make_shared(key, table))); + } + + /// Get a specific Gaussian mixture by index `i`. + GaussianMixture::shared_ptr atGaussian(size_t i) const; + + /// Get a specific discrete conditional by index `i`. + DiscreteConditional::shared_ptr atDiscrete(size_t i) const; + + /** + * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete + * value assignment. + * + * @param assignment The discrete value assignment for the discrete keys. + * @return GaussianBayesNet + */ + GaussianBayesNet choose(const DiscreteValues &assignment) const; }; } // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp new file mode 100644 index 000000000..34133ab0b --- /dev/null +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -0,0 +1,92 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridBayesNet.cpp + * @brief Unit tests for HybridBayesNet + * @author Varun Agrawal + * @author Fan Jiang + * @author Frank Dellaert + * @date December 2021 + */ + +#include +#include + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::M; +using symbol_shorthand::X; + +static const DiscreteKey Asia(0, 2); + +/* ****************************************************************************/ +// Test creation +TEST(HybridBayesNet, Creation) { + HybridBayesNet bayesNet; + + bayesNet.add(Asia, "99/1"); + + DiscreteConditional expected(Asia, "99/1"); + + CHECK(bayesNet.atDiscrete(0)); + auto& df = *bayesNet.atDiscrete(0); + EXPECT(df.equals(expected)); +} + +/* ****************************************************************************/ +// Test choosing an assignment of conditionals +TEST(HybridBayesNet, Choose) { + Switching s(4); + + Ordering ordering; + for (auto&& kvp : s.linearizationPoint) { + ordering += kvp.key; + } + + HybridBayesNet::shared_ptr hybridBayesNet; + HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; + std::tie(hybridBayesNet, remainingFactorGraph) = + s.linearizedFactorGraph.eliminatePartialSequential(ordering); + + DiscreteValues assignment; + assignment[M(1)] = 1; + assignment[M(2)] = 1; + assignment[M(3)] = 0; + + GaussianBayesNet gbn = hybridBayesNet->choose(assignment); + + EXPECT_LONGS_EQUAL(4, gbn.size()); + + EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( + hybridBayesNet->atGaussian(0)))(assignment), + *gbn.at(0))); + EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( + hybridBayesNet->atGaussian(1)))(assignment), + *gbn.at(1))); + EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( + hybridBayesNet->atGaussian(2)))(assignment), + *gbn.at(2))); + EXPECT(assert_equal(*(*boost::dynamic_pointer_cast( + hybridBayesNet->atGaussian(3)))(assignment), + *gbn.at(3))); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ \ No newline at end of file