Improved hybrid bayes net and tests

release/4.3a0
Varun Agrawal 2022-06-07 18:39:10 -04:00
parent 44079d13b4
commit 374e3cbc7a
3 changed files with 141 additions and 0 deletions

View File

@ -10,7 +10,34 @@
* @file HybridBayesNet.cpp * @file HybridBayesNet.cpp
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys. * @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal
* @date January 2022 * @date January 2022
*/ */
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
namespace gtsam {
/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner());
}
/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return boost::dynamic_pointer_cast<DiscreteConditional>(
factors_.at(i)->inner());
}
/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
GaussianMixture gm = *this->atGaussian(idx);
gbn.push_back(gm(assignment));
}
return gbn;
}
} // namespace gtsam

View File

@ -19,6 +19,7 @@
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <gtsam/linear/GaussianBayesNet.h>
namespace gtsam { namespace gtsam {
@ -36,6 +37,27 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/** Construct empty bayes net */ /** Construct empty bayes net */
HybridBayesNet() = default; HybridBayesNet() = default;
/// Add a discrete conditional to the Bayes Net.
void add(const DiscreteKey &key, const std::string &table) {
push_back(
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
}
/// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atGaussian(size_t i) const;
/// Get a specific discrete conditional by index `i`.
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
/**
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
* value assignment.
*
* @param assignment The discrete value assignment for the discrete keys.
* @return GaussianBayesNet
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -0,0 +1,92 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridBayesNet.cpp
* @brief Unit tests for HybridBayesNet
* @author Varun Agrawal
* @author Fan Jiang
* @author Frank Dellaert
* @date December 2021
*/
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/tests/Switching.h>
// Include for test suite
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
static const DiscreteKey Asia(0, 2);
/* ****************************************************************************/
// Test creation
TEST(HybridBayesNet, Creation) {
HybridBayesNet bayesNet;
bayesNet.add(Asia, "99/1");
DiscreteConditional expected(Asia, "99/1");
CHECK(bayesNet.atDiscrete(0));
auto& df = *bayesNet.atDiscrete(0);
EXPECT(df.equals(expected));
}
/* ****************************************************************************/
// Test choosing an assignment of conditionals
TEST(HybridBayesNet, Choose) {
Switching s(4);
Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}
HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteValues assignment;
assignment[M(1)] = 1;
assignment[M(2)] = 1;
assignment[M(3)] = 0;
GaussianBayesNet gbn = hybridBayesNet->choose(assignment);
EXPECT_LONGS_EQUAL(4, gbn.size());
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(0)))(assignment),
*gbn.at(0)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(1)))(assignment),
*gbn.at(1)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(2)))(assignment),
*gbn.at(2)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(3)))(assignment),
*gbn.at(3)));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */