Merge pull request #1372 from borglab/hybrid/simplifiedAPI

Simplified API for HybridBayesNet
release/4.3a0
Frank Dellaert 2023-01-05 19:27:16 -08:00 committed by GitHub
commit a3b177c604
14 changed files with 119 additions and 171 deletions

@@ -42,13 +42,21 @@ const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
   return conditionals_;
 }
 
+/* *******************************************************************************/
+GaussianMixture::GaussianMixture(
+    KeyVector &&continuousFrontals, KeyVector &&continuousParents,
+    DiscreteKeys &&discreteParents,
+    std::vector<GaussianConditional::shared_ptr> &&conditionals)
+    : GaussianMixture(continuousFrontals, continuousParents, discreteParents,
+                      Conditionals(discreteParents, conditionals)) {}
+
 /* *******************************************************************************/
 GaussianMixture::GaussianMixture(
     const KeyVector &continuousFrontals, const KeyVector &continuousParents,
     const DiscreteKeys &discreteParents,
-    const std::vector<GaussianConditional::shared_ptr> &conditionalsList)
+    const std::vector<GaussianConditional::shared_ptr> &conditionals)
     : GaussianMixture(continuousFrontals, continuousParents, discreteParents,
-                      Conditionals(discreteParents, conditionalsList)) {}
+                      Conditionals(discreteParents, conditionals)) {}
 
 /* *******************************************************************************/
 GaussianFactorGraphTree GaussianMixture::add(

@@ -104,6 +104,18 @@ class GTSAM_EXPORT GaussianMixture
       const DiscreteKeys &discreteParents,
       const Conditionals &conditionals);
 
+  /**
+   * @brief Make a Gaussian Mixture from a list of Gaussian conditionals
+   *
+   * @param continuousFrontals The continuous frontal variables
+   * @param continuousParents The continuous parent variables
+   * @param discreteParents Discrete parents variables
+   * @param conditionals List of conditionals
+   */
+  GaussianMixture(KeyVector &&continuousFrontals, KeyVector &&continuousParents,
+                  DiscreteKeys &&discreteParents,
+                  std::vector<GaussianConditional::shared_ptr> &&conditionals);
+
   /**
   * @brief Make a Gaussian Mixture from a list of Gaussian conditionals
   *
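For orientation, a minimal sketch (not part of this commit) of calling the new rvalue-reference constructor declared above; it reuses GaussianConditional::sharedMeanAndStddev and the X/Z/M symbol shorthand that appear in the test changes further down, and the function name exampleMixture is made up.

#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/inference/Symbol.h>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;

// Sketch only: a two-component mixture P(z0 | x0, m0). All four arguments are
// temporaries (or moved), so they bind to the new && overload.
GaussianMixture exampleMixture() {
  const DiscreteKey mode{M(0), 2};
  std::vector<GaussianConditional::shared_ptr> conditionals{
      GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 0.5),
      GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3.0)};
  return GaussianMixture({Z(0)}, {X(0)}, {mode}, std::move(conditionals));
}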

@@ -197,8 +197,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
       prunedGaussianMixture->prune(*decisionTree);  // imperative :-(
 
       // Type-erase and add to the pruned Bayes Net fragment.
-      prunedBayesNetFragment.push_back(
-          boost::make_shared<HybridConditional>(prunedGaussianMixture));
+      prunedBayesNetFragment.push_back(prunedGaussianMixture);
 
     } else {
       // Add the non-GaussianMixture conditional
@@ -209,21 +208,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   return prunedBayesNetFragment;
 }
 
-/* ************************************************************************* */
-GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
-  return at(i)->asMixture();
-}
-
-/* ************************************************************************* */
-GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
-  return at(i)->asGaussian();
-}
-
-/* ************************************************************************* */
-DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
-  return at(i)->asDiscrete();
-}
-
 /* ************************************************************************* */
 GaussianBayesNet HybridBayesNet::choose(
     const DiscreteValues &assignment) const {
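With atMixture, atGaussian and atDiscrete removed, callers index with at(i) and cast with asMixture()/asGaussian()/asDiscrete(), guarding on isHybrid()/isContinuous() as the updated Error test does below. A sketch of that migration, assuming GaussianMixture::error accepts HybridValues as that test suggests; the helper name totalError is made up.

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

// Sketch only: sum the error over all conditionals of a HybridBayesNet using
// the generic at(i) accessor instead of the removed typed getters.
double totalError(const gtsam::HybridBayesNet& net,
                  const gtsam::HybridValues& values) {
  double total = 0.0;
  for (size_t i = 0; i < net.size(); i++) {
    if (net.at(i)->isHybrid()) {
      total += net.at(i)->asMixture()->error(values);  // was atMixture(i)
    } else if (net.at(i)->isContinuous()) {
      total += net.at(i)->asGaussian()->error(values.continuous());  // was atGaussian(i)
    }
  }
  return total;
}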

@@ -63,55 +63,26 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   /// @{
 
   /// Add HybridConditional to Bayes Net
-  using Base::add;
+  using Base::emplace_shared;
 
-  /// Add a Gaussian Mixture to the Bayes Net.
-  void addMixture(const GaussianMixture::shared_ptr &ptr) {
-    push_back(HybridConditional(ptr));
+  /// Add a conditional directly using a pointer.
+  template <class Conditional>
+  void emplace_back(Conditional *conditional) {
+    factors_.push_back(boost::make_shared<HybridConditional>(
+        boost::shared_ptr<Conditional>(conditional)));
   }
 
-  /// Add a Gaussian conditional to the Bayes Net.
-  void addGaussian(const GaussianConditional::shared_ptr &ptr) {
-    push_back(HybridConditional(ptr));
+  /// Add a conditional directly using a shared_ptr.
+  void push_back(boost::shared_ptr<HybridConditional> conditional) {
+    factors_.push_back(conditional);
   }
 
-  /// Add a discrete conditional to the Bayes Net.
-  void addDiscrete(const DiscreteConditional::shared_ptr &ptr) {
-    push_back(HybridConditional(ptr));
+  /// Add a conditional directly using implicit conversion.
+  void push_back(HybridConditional &&conditional) {
+    factors_.push_back(
+        boost::make_shared<HybridConditional>(std::move(conditional)));
   }
 
-  /// Add a Gaussian Mixture to the Bayes Net.
-  template <typename... T>
-  void emplaceMixture(T &&...args) {
-    push_back(HybridConditional(
-        boost::make_shared<GaussianMixture>(std::forward<T>(args)...)));
-  }
-
-  /// Add a Gaussian conditional to the Bayes Net.
-  template <typename... T>
-  void emplaceGaussian(T &&...args) {
-    push_back(HybridConditional(
-        boost::make_shared<GaussianConditional>(std::forward<T>(args)...)));
-  }
-
-  /// Add a discrete conditional to the Bayes Net.
-  template <typename... T>
-  void emplaceDiscrete(T &&...args) {
-    push_back(HybridConditional(
-        boost::make_shared<DiscreteConditional>(std::forward<T>(args)...)));
-  }
-
-  using Base::push_back;
-
-  /// Get a specific Gaussian mixture by index `i`.
-  GaussianMixture::shared_ptr atMixture(size_t i) const;
-
-  /// Get a specific Gaussian conditional by index `i`.
-  GaussianConditional::shared_ptr atGaussian(size_t i) const;
-
-  /// Get a specific discrete conditional by index `i`.
-  DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
-
   /**
    * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
    * value assignment.
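As a quick reference for the insertion API above, a sketch (not from this commit; keys, numbers, and the name buildTinyNet are illustrative) that adds conditionals the way the updated tests below do:

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

HybridBayesNet buildTinyNet() {
  HybridBayesNet bn;
  const DiscreteKey mode{M(0), 2};

  // emplace_back takes ownership of a raw pointer and wraps it in a
  // HybridConditional:
  bn.emplace_back(new DiscreteConditional(mode, "4/6"));

  // push_back accepts a shared_ptr to a concrete conditional, which converts
  // implicitly to HybridConditional:
  bn.push_back(GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5));

  // Typed access now goes through at(i) plus an as*() cast:
  if (bn.at(0)->asDiscrete()) {
    // index 0 holds the discrete prior
  }
  return bn;
}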

@@ -39,7 +39,7 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
 
 /* ************************************************************************ */
 HybridConditional::HybridConditional(
-    boost::shared_ptr<GaussianConditional> continuousConditional)
+    const boost::shared_ptr<GaussianConditional> &continuousConditional)
     : HybridConditional(continuousConditional->keys(), {},
                         continuousConditional->nrFrontals()) {
   inner_ = continuousConditional;
@@ -47,7 +47,7 @@ HybridConditional::HybridConditional(
 
 /* ************************************************************************ */
 HybridConditional::HybridConditional(
-    boost::shared_ptr<DiscreteConditional> discreteConditional)
+    const boost::shared_ptr<DiscreteConditional> &discreteConditional)
     : HybridConditional({}, discreteConditional->discreteKeys(),
                         discreteConditional->nrFrontals()) {
   inner_ = discreteConditional;
@@ -55,7 +55,7 @@ HybridConditional::HybridConditional(
 
 /* ************************************************************************ */
 HybridConditional::HybridConditional(
-    boost::shared_ptr<GaussianMixture> gaussianMixture)
+    const boost::shared_ptr<GaussianMixture> &gaussianMixture)
     : BaseFactor(KeyVector(gaussianMixture->keys().begin(),
                            gaussianMixture->keys().begin() +
                                gaussianMixture->nrContinuous()),

@@ -111,7 +111,7 @@ class GTSAM_EXPORT HybridConditional
    * HybridConditional.
    */
   HybridConditional(
-      boost::shared_ptr<GaussianConditional> continuousConditional);
+      const boost::shared_ptr<GaussianConditional>& continuousConditional);
 
   /**
    * @brief Construct a new Hybrid Conditional object
@@ -119,7 +119,8 @@ class GTSAM_EXPORT HybridConditional
    * @param discreteConditional Conditional used to create the
    * HybridConditional.
    */
-  HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
+  HybridConditional(
+      const boost::shared_ptr<DiscreteConditional>& discreteConditional);
 
   /**
    * @brief Construct a new Hybrid Conditional object
@@ -127,7 +128,7 @@ class GTSAM_EXPORT HybridConditional
    * @param gaussianMixture Gaussian Mixture Conditional used to create the
    * HybridConditional.
    */
-  HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
+  HybridConditional(const boost::shared_ptr<GaussianMixture>& gaussianMixture);
 
   /// @}
   /// @name Testable

@@ -46,7 +46,7 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
   }
 
   // Add the partial bayes net to the posterior bayes net.
-  hybridBayesNet_.push_back<HybridBayesNet>(*bayesNetFragment);
+  hybridBayesNet_.add(*bayesNetFragment);
 }
 
 /* ************************************************************************* */
@@ -100,7 +100,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
 /* ************************************************************************* */
 GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
     size_t index) const {
-  return hybridBayesNet_.atMixture(index);
+  return hybridBayesNet_.at(index)->asMixture();
 }
 
 /* ************************************************************************* */

@@ -135,29 +135,9 @@ class HybridBayesTree {
 #include <gtsam/hybrid/HybridBayesNet.h>
 class HybridBayesNet {
   HybridBayesNet();
-  void add(const gtsam::HybridConditional& s);
-  void addMixture(const gtsam::GaussianMixture* s);
-  void addGaussian(const gtsam::GaussianConditional* s);
-  void addDiscrete(const gtsam::DiscreteConditional* s);
-
-  void emplaceMixture(const gtsam::GaussianMixture& s);
-  void emplaceMixture(const gtsam::KeyVector& continuousFrontals,
-                      const gtsam::KeyVector& continuousParents,
-                      const gtsam::DiscreteKeys& discreteParents,
-                      const std::vector<gtsam::GaussianConditional::shared_ptr>&
-                          conditionalsList);
-  void emplaceGaussian(const gtsam::GaussianConditional& s);
-  void emplaceDiscrete(const gtsam::DiscreteConditional& s);
-  void emplaceDiscrete(const gtsam::DiscreteKey& key, string spec);
-  void emplaceDiscrete(const gtsam::DiscreteKey& key,
-                       const gtsam::DiscreteKeys& parents, string spec);
-  void emplaceDiscrete(const gtsam::DiscreteKey& key,
-                       const std::vector<gtsam::DiscreteKey>& parents,
-                       string spec);
-  gtsam::GaussianMixture* atMixture(size_t i) const;
-  gtsam::GaussianConditional* atGaussian(size_t i) const;
-  gtsam::DiscreteConditional* atDiscrete(size_t i) const;
+  void push_back(const gtsam::GaussianMixture* s);
+  void push_back(const gtsam::GaussianConditional* s);
+  void push_back(const gtsam::DiscreteConditional* s);
 
   bool empty() const;
   size_t size() const;

@@ -43,22 +43,22 @@ inline HybridBayesNet createHybridBayesNet(int num_measurements = 1,
   // Create Gaussian mixture z_i = x0 + noise for each measurement.
   for (int i = 0; i < num_measurements; i++) {
     const auto mode_i = manyModes ? DiscreteKey{M(i), 2} : mode;
-    GaussianMixture gm({Z(i)}, {X(0)}, {mode_i},
-                       {GaussianConditional::sharedMeanAndStddev(
-                            Z(i), I_1x1, X(0), Z_1x1, 0.5),
-                        GaussianConditional::sharedMeanAndStddev(
-                            Z(i), I_1x1, X(0), Z_1x1, 3)});
-    bayesNet.emplaceMixture(gm);  // copy :-(
+    bayesNet.emplace_back(
+        new GaussianMixture({Z(i)}, {X(0)}, {mode_i},
+                            {GaussianConditional::sharedMeanAndStddev(
+                                 Z(i), I_1x1, X(0), Z_1x1, 0.5),
+                             GaussianConditional::sharedMeanAndStddev(
+                                 Z(i), I_1x1, X(0), Z_1x1, 3)}));
   }
 
   // Create prior on X(0).
-  bayesNet.addGaussian(
+  bayesNet.push_back(
       GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5));
 
   // Add prior on mode.
   const size_t nrModes = manyModes ? num_measurements : 1;
   for (int i = 0; i < nrModes; i++) {
-    bayesNet.emplaceDiscrete(DiscreteKey{M(i), 2}, "4/6");
+    bayesNet.emplace_back(new DiscreteConditional({M(i), 2}, "4/6"));
   }
   return bayesNet;
 }

@@ -42,21 +42,21 @@ static const DiscreteKey Asia(asiaKey, 2);
 // Test creation of a pure discrete Bayes net.
 TEST(HybridBayesNet, Creation) {
   HybridBayesNet bayesNet;
-  bayesNet.emplaceDiscrete(Asia, "99/1");
+  bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
 
   DiscreteConditional expected(Asia, "99/1");
-  CHECK(bayesNet.atDiscrete(0));
-  EXPECT(assert_equal(expected, *bayesNet.atDiscrete(0)));
+  CHECK(bayesNet.at(0)->asDiscrete());
+  EXPECT(assert_equal(expected, *bayesNet.at(0)->asDiscrete()));
 }
 
 /* ****************************************************************************/
 // Test adding a Bayes net to another one.
 TEST(HybridBayesNet, Add) {
   HybridBayesNet bayesNet;
-  bayesNet.emplaceDiscrete(Asia, "99/1");
+  bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
 
   HybridBayesNet other;
-  other.push_back(bayesNet);
+  other.add(bayesNet);
   EXPECT(bayesNet.equals(other));
 }
@@ -64,7 +64,7 @@ TEST(HybridBayesNet, Add) {
 // Test evaluate for a pure discrete Bayes net P(Asia).
 TEST(HybridBayesNet, EvaluatePureDiscrete) {
   HybridBayesNet bayesNet;
-  bayesNet.emplaceDiscrete(Asia, "99/1");
+  bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
   HybridValues values;
   values.insert(asiaKey, 0);
   EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
@@ -80,7 +80,7 @@ TEST(HybridBayesNet, Tiny) {
 /* ****************************************************************************/
 // Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
 TEST(HybridBayesNet, evaluateHybrid) {
-  const auto continuousConditional = GaussianConditional::FromMeanAndStddev(
+  const auto continuousConditional = GaussianConditional::sharedMeanAndStddev(
       X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0);
 
   const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)),
@@ -93,10 +93,11 @@ TEST(HybridBayesNet, evaluateHybrid) {
 
   // Create hybrid Bayes net.
   HybridBayesNet bayesNet;
-  bayesNet.emplaceGaussian(continuousConditional);
-  GaussianMixture gm({X(1)}, {}, {Asia}, {conditional0, conditional1});
-  bayesNet.emplaceMixture(gm);  // copy :-(
-  bayesNet.emplaceDiscrete(Asia, "99/1");
+  bayesNet.push_back(GaussianConditional::sharedMeanAndStddev(
+      X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0));
+  bayesNet.emplace_back(
+      new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1}));
+  bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
 
   // Create values at which to evaluate.
   HybridValues values;
@@ -105,7 +106,7 @@
   values.insert(X(1), Vector1(1));
 
   const double conditionalProbability =
-      continuousConditional.evaluate(values.continuous());
+      continuousConditional->evaluate(values.continuous());
   const double mixtureProbability = conditional0->evaluate(values.continuous());
   EXPECT_DOUBLES_EQUAL(conditionalProbability * mixtureProbability * 0.99,
                        bayesNet.evaluate(values), 1e-9);
@@ -135,17 +136,13 @@ TEST(HybridBayesNet, Choose) {
 
   EXPECT_LONGS_EQUAL(4, gbn.size());
 
-  EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
-                          hybridBayesNet->atMixture(0)))(assignment),
+  EXPECT(assert_equal(*(*hybridBayesNet->at(0)->asMixture())(assignment),
                       *gbn.at(0)));
-  EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
-                          hybridBayesNet->atMixture(1)))(assignment),
+  EXPECT(assert_equal(*(*hybridBayesNet->at(1)->asMixture())(assignment),
                       *gbn.at(1)));
-  EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
-                          hybridBayesNet->atMixture(2)))(assignment),
+  EXPECT(assert_equal(*(*hybridBayesNet->at(2)->asMixture())(assignment),
                       *gbn.at(2)));
-  EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
-                          hybridBayesNet->atMixture(3)))(assignment),
+  EXPECT(assert_equal(*(*hybridBayesNet->at(3)->asMixture())(assignment),
                       *gbn.at(3)));
 }
@@ -247,11 +244,12 @@ TEST(HybridBayesNet, Error) {
   double total_error = 0;
   for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
     if (hybridBayesNet->at(idx)->isHybrid()) {
-      double error = hybridBayesNet->atMixture(idx)->error(
+      double error = hybridBayesNet->at(idx)->asMixture()->error(
           {delta.continuous(), discrete_values});
       total_error += error;
     } else if (hybridBayesNet->at(idx)->isContinuous()) {
-      double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
+      double error =
+          hybridBayesNet->at(idx)->asGaussian()->error(delta.continuous());
       total_error += error;
     }
   }

@@ -310,7 +310,7 @@ TEST(HybridEstimation, Probability) {
   for (auto discrete_conditional : *discreteBayesNet) {
     bayesNet->add(discrete_conditional);
   }
-  auto discreteConditional = discreteBayesNet->atDiscrete(0);
+  auto discreteConditional = discreteBayesNet->at(0)->asDiscrete();
 
   HybridValues hybrid_values = bayesNet->optimize();

@@ -677,11 +677,11 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
                    X(0), Vector1(14.1421), I_1x1 * 2.82843),
                conditional1 = boost::make_shared<GaussianConditional>(
                    X(0), Vector1(10.1379), I_1x1 * 2.02759);
-  GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1});
-  expectedBayesNet.emplaceMixture(gm);  // copy :-(
+  expectedBayesNet.emplace_back(
+      new GaussianMixture({X(0)}, {}, {mode}, {conditional0, conditional1}));
 
   // Add prior on mode.
-  expectedBayesNet.emplaceDiscrete(mode, "74/26");
+  expectedBayesNet.emplace_back(new DiscreteConditional(mode, "74/26"));
 
   // Test elimination
   Ordering ordering;
@@ -712,11 +712,11 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
                    X(0), Vector1(17.3205), I_1x1 * 3.4641),
                conditional1 = boost::make_shared<GaussianConditional>(
                    X(0), Vector1(10.274), I_1x1 * 2.0548);
-  GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1});
-  expectedBayesNet.emplaceMixture(gm);  // copy :-(
+  expectedBayesNet.emplace_back(
+      new GaussianMixture({X(0)}, {}, {mode}, {conditional0, conditional1}));
 
   // Add prior on mode.
-  expectedBayesNet.emplaceDiscrete(mode, "23/77");
+  expectedBayesNet.emplace_back(new DiscreteConditional(mode, "23/77"));
 
   // Test elimination
   Ordering ordering;
@@ -764,13 +764,10 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) {
   // regression
   EXPECT_DOUBLES_EQUAL(0.018253037966018862, expected_ratio, 1e-6);
 
-  // 3. Do sampling
+  // Test ratios for a number of independent samples:
   constexpr int num_samples = 100;
   for (size_t i = 0; i < num_samples; i++) {
-    // Sample from the bayes net
     HybridValues sample = bn.sample(&rng);
-
-    // Check that the ratio is constant.
     EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6);
   }
 }
@@ -787,34 +784,34 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
   for (size_t t : {0, 1, 2}) {
     // Create Gaussian mixture on Z(t) conditioned on X(t) and mode N(t):
     const auto noise_mode_t = DiscreteKey{N(t), 2};
-    GaussianMixture gm({Z(t)}, {X(t)}, {noise_mode_t},
-                       {GaussianConditional::sharedMeanAndStddev(
-                            Z(t), I_1x1, X(t), Z_1x1, 0.5),
-                        GaussianConditional::sharedMeanAndStddev(
-                            Z(t), I_1x1, X(t), Z_1x1, 3.0)});
-    bn.emplaceMixture(gm);  // copy :-(
+    bn.emplace_back(
+        new GaussianMixture({Z(t)}, {X(t)}, {noise_mode_t},
+                            {GaussianConditional::sharedMeanAndStddev(
+                                 Z(t), I_1x1, X(t), Z_1x1, 0.5),
+                             GaussianConditional::sharedMeanAndStddev(
+                                 Z(t), I_1x1, X(t), Z_1x1, 3.0)}));
 
     // Create prior on discrete mode M(t):
-    bn.emplaceDiscrete(noise_mode_t, "20/80");
+    bn.emplace_back(new DiscreteConditional(noise_mode_t, "20/80"));
   }
 
   // Add motion models:
   for (size_t t : {2, 1}) {
     // Create Gaussian mixture on X(t) conditioned on X(t-1) and mode M(t-1):
     const auto motion_model_t = DiscreteKey{M(t), 2};
-    GaussianMixture gm({X(t)}, {X(t - 1)}, {motion_model_t},
-                       {GaussianConditional::sharedMeanAndStddev(
-                            X(t), I_1x1, X(t - 1), Z_1x1, 0.2),
-                        GaussianConditional::sharedMeanAndStddev(
-                            X(t), I_1x1, X(t - 1), I_1x1, 0.2)});
-    bn.emplaceMixture(gm);  // copy :-(
+    bn.emplace_back(
+        new GaussianMixture({X(t)}, {X(t - 1)}, {motion_model_t},
+                            {GaussianConditional::sharedMeanAndStddev(
+                                 X(t), I_1x1, X(t - 1), Z_1x1, 0.2),
+                             GaussianConditional::sharedMeanAndStddev(
+                                 X(t), I_1x1, X(t - 1), I_1x1, 0.2)}));
 
     // Create prior on motion model M(t):
-    bn.emplaceDiscrete(motion_model_t, "40/60");
+    bn.emplace_back(new DiscreteConditional(motion_model_t, "40/60"));
   }
 
   // Create Gaussian prior on continuous X(0) using sharedMeanAndStddev:
-  bn.addGaussian(GaussianConditional::sharedMeanAndStddev(X(0), Z_1x1, 0.1));
+  bn.push_back(GaussianConditional::sharedMeanAndStddev(X(0), Z_1x1, 0.1));
 
   // Make sure we an sample from the Bayes net:
   EXPECT_LONGS_EQUAL(6, bn.sample().continuous().size());
@@ -822,7 +819,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
   // Create measurements consistent with moving right every time:
   const VectorValues measurements{
       {Z(0), Vector1(0.0)}, {Z(1), Vector1(1.0)}, {Z(2), Vector1(2.0)}};
-  const auto fg = bn.toFactorGraph(measurements);
+  const HybridGaussianFactorGraph fg = bn.toFactorGraph(measurements);
 
   // Create ordering that eliminates in time order, then discrete modes:
   Ordering ordering;
@@ -835,11 +832,11 @@
   ordering.push_back(M(1));
   ordering.push_back(M(2));
 
-  // Test elimination result has correct size:
-  const auto posterior = fg.eliminateSequential(ordering);
+  // Do elimination:
+  const HybridBayesNet::shared_ptr posterior = fg.eliminateSequential(ordering);
   // GTSAM_PRINT(*posterior);
 
-  // Test elimination result has correct size:
+  // Test resulting posterior Bayes net has correct size:
   EXPECT_LONGS_EQUAL(8, posterior->size());
 
   // TODO(dellaert): below is copy/pasta from above, refactor
@@ -861,13 +858,10 @@
   // regression
   EXPECT_DOUBLES_EQUAL(0.0094526745785019472, expected_ratio, 1e-6);
 
-  // 3. Do sampling
+  // Test ratios for a number of independent samples:
   constexpr int num_samples = 100;
   for (size_t i = 0; i < num_samples; i++) {
-    // Sample from the bayes net
     HybridValues sample = bn.sample(&rng);
-
-    // Check that the ratio is constant.
     EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6);
   }
 }

@@ -16,13 +16,13 @@ import numpy as np
 
 from gtsam.symbol_shorthand import A, X
 from gtsam.utils.test_case import GtsamTestCase
-import gtsam
-from gtsam import (DiscreteKeys, GaussianConditional, GaussianMixture,
+from gtsam import (DiscreteKeys, GaussianMixture, DiscreteConditional, GaussianConditional,
                    HybridBayesNet, HybridValues, noiseModel)
 
 
 class TestHybridBayesNet(GtsamTestCase):
     """Unit tests for HybridValues."""
 
     def test_evaluate(self):
         """Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia)."""
         asiaKey = A(0)
@@ -40,15 +40,15 @@ class TestHybridBayesNet(GtsamTestCase):
         # Create the conditionals
         conditional0 = GaussianConditional(X(1), [5], I_1x1, model0)
         conditional1 = GaussianConditional(X(1), [2], I_1x1, model1)
-        dkeys = DiscreteKeys()
-        dkeys.push_back(Asia)
-        gm = GaussianMixture([X(1)], [], dkeys, [conditional0, conditional1])
+        discrete_keys = DiscreteKeys()
+        discrete_keys.push_back(Asia)
 
         # Create hybrid Bayes net.
         bayesNet = HybridBayesNet()
-        bayesNet.addGaussian(gc)
-        bayesNet.addMixture(gm)
-        bayesNet.emplaceDiscrete(Asia, "99/1")
+        bayesNet.push_back(gc)
+        bayesNet.push_back(GaussianMixture(
+            [X(1)], [], discrete_keys, [conditional0, conditional1]))
+        bayesNet.push_back(DiscreteConditional(Asia, "99/1"))
 
         # Create values at which to evaluate.
         values = HybridValues()

@@ -108,16 +108,16 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
                                                            I_1x1,
                                                            X(0), [0],
                                                            sigma=3)
-            bayesNet.emplaceMixture([Z(i)], [X(0)], keys,
-                                    [conditional0, conditional1])
+            bayesNet.push_back(GaussianMixture([Z(i)], [X(0)], keys,
+                                               [conditional0, conditional1]))
 
         # Create prior on X(0).
         prior_on_x0 = GaussianConditional.FromMeanAndStddev(
             X(0), [prior_mean], prior_sigma)
-        bayesNet.addGaussian(prior_on_x0)
+        bayesNet.push_back(prior_on_x0)
 
         # Add prior on mode.
-        bayesNet.emplaceDiscrete(mode, "4/6")
+        bayesNet.push_back(DiscreteConditional(mode, "4/6"))
 
         return bayesNet
 
@@ -163,11 +163,11 @@
         fg = HybridGaussianFactorGraph()
         num_measurements = bayesNet.size() - 2
         for i in range(num_measurements):
-            conditional = bayesNet.atMixture(i)
+            conditional = bayesNet.at(i).asMixture()
             factor = conditional.likelihood(cls.measurements(sample, [i]))
             fg.push_back(factor)
-        fg.push_back(bayesNet.atGaussian(num_measurements))
-        fg.push_back(bayesNet.atDiscrete(num_measurements+1))
+        fg.push_back(bayesNet.at(num_measurements).asGaussian())
+        fg.push_back(bayesNet.at(num_measurements+1).asDiscrete())
         return fg
 
     @classmethod