Improved hybrid bayes net and tests
parent
44079d13b4
commit
374e3cbc7a
|
|
@ -10,7 +10,34 @@
|
||||||
* @file HybridBayesNet.cpp
|
* @file HybridBayesNet.cpp
|
||||||
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
|
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
|
* @author Varun Agrawal
|
||||||
* @date January 2022
|
* @date January 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||||
|
return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
||||||
|
return boost::dynamic_pointer_cast<DiscreteConditional>(
|
||||||
|
factors_.at(i)->inner());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
GaussianBayesNet HybridBayesNet::choose(
|
||||||
|
const DiscreteValues &assignment) const {
|
||||||
|
GaussianBayesNet gbn;
|
||||||
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
|
GaussianMixture gm = *this->atGaussian(idx);
|
||||||
|
gbn.push_back(gm(assignment));
|
||||||
|
}
|
||||||
|
return gbn;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -36,6 +37,27 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
|
|
||||||
/** Construct empty bayes net */
|
/** Construct empty bayes net */
|
||||||
HybridBayesNet() = default;
|
HybridBayesNet() = default;
|
||||||
|
|
||||||
|
/// Add a discrete conditional to the Bayes Net.
|
||||||
|
void add(const DiscreteKey &key, const std::string &table) {
|
||||||
|
push_back(
|
||||||
|
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a specific Gaussian mixture by index `i`.
|
||||||
|
GaussianMixture::shared_ptr atGaussian(size_t i) const;
|
||||||
|
|
||||||
|
/// Get a specific discrete conditional by index `i`.
|
||||||
|
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
|
||||||
|
* value assignment.
|
||||||
|
*
|
||||||
|
* @param assignment The discrete value assignment for the discrete keys.
|
||||||
|
* @return GaussianBayesNet
|
||||||
|
*/
|
||||||
|
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file testHybridBayesNet.cpp
|
||||||
|
* @brief Unit tests for HybridBayesNet
|
||||||
|
* @author Varun Agrawal
|
||||||
|
* @author Fan Jiang
|
||||||
|
* @author Frank Dellaert
|
||||||
|
* @date December 2021
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
|
#include <gtsam/hybrid/tests/Switching.h>
|
||||||
|
|
||||||
|
// Include for test suite
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
using noiseModel::Isotropic;
|
||||||
|
using symbol_shorthand::M;
|
||||||
|
using symbol_shorthand::X;
|
||||||
|
|
||||||
|
static const DiscreteKey Asia(0, 2);
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test creation
|
||||||
|
TEST(HybridBayesNet, Creation) {
|
||||||
|
HybridBayesNet bayesNet;
|
||||||
|
|
||||||
|
bayesNet.add(Asia, "99/1");
|
||||||
|
|
||||||
|
DiscreteConditional expected(Asia, "99/1");
|
||||||
|
|
||||||
|
CHECK(bayesNet.atDiscrete(0));
|
||||||
|
auto& df = *bayesNet.atDiscrete(0);
|
||||||
|
EXPECT(df.equals(expected));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test choosing an assignment of conditionals
|
||||||
|
TEST(HybridBayesNet, Choose) {
|
||||||
|
Switching s(4);
|
||||||
|
|
||||||
|
Ordering ordering;
|
||||||
|
for (auto&& kvp : s.linearizationPoint) {
|
||||||
|
ordering += kvp.key;
|
||||||
|
}
|
||||||
|
|
||||||
|
HybridBayesNet::shared_ptr hybridBayesNet;
|
||||||
|
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
|
||||||
|
std::tie(hybridBayesNet, remainingFactorGraph) =
|
||||||
|
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||||
|
|
||||||
|
DiscreteValues assignment;
|
||||||
|
assignment[M(1)] = 1;
|
||||||
|
assignment[M(2)] = 1;
|
||||||
|
assignment[M(3)] = 0;
|
||||||
|
|
||||||
|
GaussianBayesNet gbn = hybridBayesNet->choose(assignment);
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(4, gbn.size());
|
||||||
|
|
||||||
|
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||||
|
hybridBayesNet->atGaussian(0)))(assignment),
|
||||||
|
*gbn.at(0)));
|
||||||
|
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||||
|
hybridBayesNet->atGaussian(1)))(assignment),
|
||||||
|
*gbn.at(1)));
|
||||||
|
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||||
|
hybridBayesNet->atGaussian(2)))(assignment),
|
||||||
|
*gbn.at(2)));
|
||||||
|
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||||
|
hybridBayesNet->atGaussian(3)))(assignment),
|
||||||
|
*gbn.at(3)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
Loading…
Reference in New Issue