diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index a25897ffa..78efd57e2 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { }; #include +#include virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(); DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); @@ -95,9 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); + + // Standard interface + double logNormalizationConstant() const; double logProbability(const gtsam::DiscreteValues& values) const; double evaluate(const gtsam::DiscreteValues& values) const; - double operator()(const gtsam::DiscreteValues& values) const; + double error(const gtsam::DiscreteValues& values) const; gtsam::DiscreteConditional operator*( const gtsam::DiscreteConditional& other) const; gtsam::DiscreteConditional marginal(gtsam::Key key) const; @@ -119,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { size_t sample(size_t value) const; size_t sample() const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; + + // Markdown and HTML string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; string markdown(const gtsam::KeyFormatter& keyFormatter, @@ -127,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { gtsam::DefaultKeyFormatter) const; string html(const gtsam::KeyFormatter& keyFormatter, std::map> names) const; + + // Expose HybridValues versions + double logProbability(const gtsam::HybridValues& x) const; + double evaluate(const gtsam::HybridValues& x) const; + double error(const gtsam::HybridValues& x) const; }; #include diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index aad1cca9b..bbadd1aa8 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -61,6 +61,7 @@ virtual class HybridConditional { size_t nrParents() const; // Standard interface: + double logNormalizationConstant() const; double logProbability(const gtsam::HybridValues& values) const; double evaluate(const gtsam::HybridValues& values) const; double operator()(const gtsam::HybridValues& values) const; diff --git a/gtsam/linear/linear.i b/gtsam/linear/linear.i index 2d88c5f93..c0230f1c2 100644 --- a/gtsam/linear/linear.i +++ b/gtsam/linear/linear.i @@ -456,6 +456,7 @@ class GaussianFactorGraph { }; #include +#include virtual class GaussianConditional : gtsam::JacobianFactor { // Constructors GaussianConditional(size_t key, Vector d, Matrix R, @@ -497,6 +498,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor { bool equals(const gtsam::GaussianConditional& cg, double tol) const; // Standard Interface + double logNormalizationConstant() const; double logProbability(const gtsam::VectorValues& x) const; double evaluate(const gtsam::VectorValues& x) const; double error(const gtsam::VectorValues& x) const; @@ -518,6 +520,11 @@ virtual class GaussianConditional : gtsam::JacobianFactor { // enabling serialization functionality void serialize() const; + + // Expose HybridValues versions + double logProbability(const gtsam::HybridValues& x) const; + double evaluate(const gtsam::HybridValues& x) const; + double error(const gtsam::HybridValues& x) const; }; #include diff --git a/python/gtsam/tests/test_HybridBayesNet.py b/python/gtsam/tests/test_HybridBayesNet.py index c949551c4..01e1c5a5d 100644 --- a/python/gtsam/tests/test_HybridBayesNet.py +++ b/python/gtsam/tests/test_HybridBayesNet.py @@ -17,8 +17,8 @@ import numpy as np from gtsam.symbol_shorthand import A, X from gtsam.utils.test_case import GtsamTestCase -from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, - GaussianMixture, HybridBayesNet, HybridValues, noiseModel) +from gtsam import (DiscreteConditional, DiscreteKeys, DiscreteValues, GaussianConditional, + GaussianMixture, HybridBayesNet, HybridValues, noiseModel, VectorValues) class TestHybridBayesNet(GtsamTestCase): @@ -53,9 +53,13 @@ class TestHybridBayesNet(GtsamTestCase): # Create values at which to evaluate. values = HybridValues() - values.insert(asiaKey, 0) - values.insert(X(0), [-6]) - values.insert(X(1), [1]) + continuous = VectorValues() + continuous.insert(X(0), [-6]) + continuous.insert(X(1), [1]) + values.insert(continuous) + discrete = DiscreteValues() + discrete[asiaKey] = 0 + values.insert(discrete) conditionalProbability = conditional.evaluate(values.continuous()) mixtureProbability = conditional0.evaluate(values.continuous()) @@ -68,6 +72,26 @@ class TestHybridBayesNet(GtsamTestCase): self.assertAlmostEqual(bayesNet.logProbability(values), math.log(bayesNet.evaluate(values))) + # Check invariance for all conditionals: + self.check_invariance(bayesNet.at(0).asGaussian(), continuous) + self.check_invariance(bayesNet.at(0).asGaussian(), values) + self.check_invariance(bayesNet.at(0), values) + + self.check_invariance(bayesNet.at(1), values) + + self.check_invariance(bayesNet.at(2).asDiscrete(), discrete) + self.check_invariance(bayesNet.at(2).asDiscrete(), values) + self.check_invariance(bayesNet.at(2), values) + + def check_invariance(self, conditional, values): + """Check invariance for given conditional.""" + probability = conditional.evaluate(values) + self.assertTrue(probability >= 0.0) + logProb = conditional.logProbability(values) + self.assertAlmostEqual(probability, np.exp(logProb)) + expected = conditional.logNormalizationConstant() - conditional.error(values) + self.assertAlmostEqual(logProb, expected) + if __name__ == "__main__": unittest.main()