Fix python wrapper
parent
8b8cde4230
commit
d49bcce780
|
|
@ -135,29 +135,9 @@ class HybridBayesTree {
|
|||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
class HybridBayesNet {
|
||||
HybridBayesNet();
|
||||
void add(const gtsam::HybridConditional& s);
|
||||
void addMixture(const gtsam::GaussianMixture* s);
|
||||
void addGaussian(const gtsam::GaussianConditional* s);
|
||||
void addDiscrete(const gtsam::DiscreteConditional* s);
|
||||
|
||||
void emplaceMixture(const gtsam::GaussianMixture& s);
|
||||
void emplaceMixture(const gtsam::KeyVector& continuousFrontals,
|
||||
const gtsam::KeyVector& continuousParents,
|
||||
const gtsam::DiscreteKeys& discreteParents,
|
||||
const std::vector<gtsam::GaussianConditional::shared_ptr>&
|
||||
conditionalsList);
|
||||
void emplaceGaussian(const gtsam::GaussianConditional& s);
|
||||
void emplaceDiscrete(const gtsam::DiscreteConditional& s);
|
||||
void emplaceDiscrete(const gtsam::DiscreteKey& key, string spec);
|
||||
void emplaceDiscrete(const gtsam::DiscreteKey& key,
|
||||
const gtsam::DiscreteKeys& parents, string spec);
|
||||
void emplaceDiscrete(const gtsam::DiscreteKey& key,
|
||||
const std::vector<gtsam::DiscreteKey>& parents,
|
||||
string spec);
|
||||
|
||||
gtsam::GaussianMixture* atMixture(size_t i) const;
|
||||
gtsam::GaussianConditional* atGaussian(size_t i) const;
|
||||
gtsam::DiscreteConditional* atDiscrete(size_t i) const;
|
||||
void push_back(const gtsam::GaussianMixture* s);
|
||||
void push_back(const gtsam::GaussianConditional* s);
|
||||
void push_back(const gtsam::DiscreteConditional* s);
|
||||
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
|
|
|
|||
|
|
@ -16,13 +16,13 @@ import numpy as np
|
|||
from gtsam.symbol_shorthand import A, X
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
import gtsam
|
||||
from gtsam import (DiscreteKeys, GaussianConditional, GaussianMixture,
|
||||
from gtsam import (DiscreteKeys, GaussianMixture, DiscreteConditional, GaussianConditional, GaussianMixture,
|
||||
HybridBayesNet, HybridValues, noiseModel)
|
||||
|
||||
|
||||
class TestHybridBayesNet(GtsamTestCase):
|
||||
"""Unit tests for HybridValues."""
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia)."""
|
||||
asiaKey = A(0)
|
||||
|
|
@ -40,15 +40,15 @@ class TestHybridBayesNet(GtsamTestCase):
|
|||
# Create the conditionals
|
||||
conditional0 = GaussianConditional(X(1), [5], I_1x1, model0)
|
||||
conditional1 = GaussianConditional(X(1), [2], I_1x1, model1)
|
||||
dkeys = DiscreteKeys()
|
||||
dkeys.push_back(Asia)
|
||||
gm = GaussianMixture([X(1)], [], dkeys, [conditional0, conditional1])
|
||||
discrete_keys = DiscreteKeys()
|
||||
discrete_keys.push_back(Asia)
|
||||
|
||||
# Create hybrid Bayes net.
|
||||
bayesNet = HybridBayesNet()
|
||||
bayesNet.addGaussian(gc)
|
||||
bayesNet.addMixture(gm)
|
||||
bayesNet.emplaceDiscrete(Asia, "99/1")
|
||||
bayesNet.push_back(gc)
|
||||
bayesNet.push_back(GaussianMixture(
|
||||
[X(1)], [], discrete_keys, [conditional0, conditional1]))
|
||||
bayesNet.push_back(DiscreteConditional(Asia, "99/1"))
|
||||
|
||||
# Create values at which to evaluate.
|
||||
values = HybridValues()
|
||||
|
|
|
|||
|
|
@ -108,16 +108,16 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
I_1x1,
|
||||
X(0), [0],
|
||||
sigma=3)
|
||||
bayesNet.emplaceMixture([Z(i)], [X(0)], keys,
|
||||
[conditional0, conditional1])
|
||||
bayesNet.push_back(GaussianMixture([Z(i)], [X(0)], keys,
|
||||
[conditional0, conditional1]))
|
||||
|
||||
# Create prior on X(0).
|
||||
prior_on_x0 = GaussianConditional.FromMeanAndStddev(
|
||||
X(0), [prior_mean], prior_sigma)
|
||||
bayesNet.addGaussian(prior_on_x0)
|
||||
bayesNet.push_back(prior_on_x0)
|
||||
|
||||
# Add prior on mode.
|
||||
bayesNet.emplaceDiscrete(mode, "4/6")
|
||||
bayesNet.push_back(DiscreteConditional(mode, "4/6"))
|
||||
|
||||
return bayesNet
|
||||
|
||||
|
|
@ -163,11 +163,11 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
fg = HybridGaussianFactorGraph()
|
||||
num_measurements = bayesNet.size() - 2
|
||||
for i in range(num_measurements):
|
||||
conditional = bayesNet.atMixture(i)
|
||||
conditional = bayesNet.at(i).asMixture()
|
||||
factor = conditional.likelihood(cls.measurements(sample, [i]))
|
||||
fg.push_back(factor)
|
||||
fg.push_back(bayesNet.atGaussian(num_measurements))
|
||||
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
|
||||
fg.push_back(bayesNet.at(num_measurements).asGaussian())
|
||||
fg.push_back(bayesNet.at(num_measurements+1).asDiscrete())
|
||||
return fg
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
Loading…
Reference in New Issue