From 87ff4af32dcbeaa97f536da9de2931ad43483289 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 10 Jan 2023 23:13:40 -0800 Subject: [PATCH] Wrapper and tests for logProbability --- gtsam/discrete/discrete.i | 8 ++++++++ gtsam/hybrid/hybrid.i | 8 +++++++- gtsam/linear/linear.i | 3 +++ python/gtsam/tests/test_DiscreteBayesNet.py | 18 +++++++++++++----- python/gtsam/tests/test_GaussianBayesNet.py | 18 ++++++++++++++++-- python/gtsam/tests/test_HybridBayesNet.py | 17 +++++++++++------ 6 files changed, 58 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index fa98f36fa..a25897ffa 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -95,6 +95,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); + double logProbability(const gtsam::DiscreteValues& values) const; + double evaluate(const gtsam::DiscreteValues& values) const; + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteConditional operator*( const gtsam::DiscreteConditional& other) const; gtsam::DiscreteConditional marginal(gtsam::Key key) const; @@ -157,7 +160,12 @@ class DiscreteBayesNet { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; + + // Standard interface. + double logProbability(const gtsam::DiscreteValues& values) const; + double evaluate(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 012f707e4..aad1cca9b 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -61,6 +61,9 @@ virtual class HybridConditional { size_t nrParents() const; // Standard interface: + double logProbability(const gtsam::HybridValues& values) const; + double evaluate(const gtsam::HybridValues& values) const; + double operator()(const gtsam::HybridValues& values) const; gtsam::GaussianMixture* asMixture() const; gtsam::GaussianConditional* asGaussian() const; gtsam::DiscreteConditional* asDiscrete() const; @@ -133,7 +136,10 @@ class HybridBayesNet { gtsam::KeySet keys() const; const gtsam::HybridConditional* at(size_t i) const; - double evaluate(const gtsam::HybridValues& x) const; + // Standard interface: + double logProbability(const gtsam::HybridValues& values) const; + double evaluate(const gtsam::HybridValues& values) const; + gtsam::HybridValues optimize() const; gtsam::HybridValues sample(const gtsam::HybridValues &given) const; gtsam::HybridValues sample() const; diff --git a/gtsam/linear/linear.i b/gtsam/linear/linear.i index 41bce61d1..2d88c5f93 100644 --- a/gtsam/linear/linear.i +++ b/gtsam/linear/linear.i @@ -497,6 +497,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor { bool equals(const gtsam::GaussianConditional& cg, double tol) const; // Standard Interface + double logProbability(const gtsam::VectorValues& x) const; double evaluate(const gtsam::VectorValues& x) const; double error(const gtsam::VectorValues& x) const; gtsam::Key firstFrontalKey() const; @@ -558,6 +559,8 @@ virtual class GaussianBayesNet { gtsam::GaussianConditional* back() const; // Standard interface + // Standard Interface + double logProbability(const gtsam::VectorValues& x) const; double evaluate(const gtsam::VectorValues& x) const; double error(const gtsam::VectorValues& x) const; diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index ff2ba99d1..d597effa8 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -11,13 +11,15 @@ Author: Frank Dellaert # pylint: disable=no-name-in-module, invalid-name +import math import textwrap import unittest +from gtsam.utils.test_case import GtsamTestCase + import gtsam from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteDistribution, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering) -from gtsam.utils.test_case import GtsamTestCase # Some keys: Asia = (0, 2) @@ -111,7 +113,7 @@ class TestDiscreteBayesNet(GtsamTestCase): self.assertEqual(len(actualSample), 8) def test_fragment(self): - """Test sampling and optimizing for Asia fragment.""" + """Test evaluate/sampling/optimizing for Asia fragment.""" # Create a reverse-topologically sorted fragment: fragment = DiscreteBayesNet() @@ -125,8 +127,14 @@ class TestDiscreteBayesNet(GtsamTestCase): given[key[0]] = 0 # Now sample from fragment: - actual = fragment.sample(given) - self.assertEqual(len(actual), 5) + values = fragment.sample(given) + self.assertEqual(len(values), 5) + + for i in [0, 1, 2]: + self.assertAlmostEqual(fragment.at(i).logProbability(values), + math.log(fragment.at(i).evaluate(values))) + self.assertAlmostEqual(fragment.logProbability(values), + math.log(fragment.evaluate(values))) def test_dot(self): """Check that dot works with position hints.""" @@ -139,7 +147,7 @@ class TestDiscreteBayesNet(GtsamTestCase): # Make sure we can *update* position hints writer = gtsam.DotWriter() ph: dict = writer.positionHints - ph['a'] = 2 # hint at symbol position + ph['a'] = 2 # hint at symbol position writer.positionHints = ph # Check the output of dot diff --git a/python/gtsam/tests/test_GaussianBayesNet.py b/python/gtsam/tests/test_GaussianBayesNet.py index 022de8c3f..9065c7bee 100644 --- a/python/gtsam/tests/test_GaussianBayesNet.py +++ b/python/gtsam/tests/test_GaussianBayesNet.py @@ -12,13 +12,15 @@ Author: Frank Dellaert from __future__ import print_function +import math import unittest -import gtsam import numpy as np -from gtsam import GaussianBayesNet, GaussianConditional from gtsam.utils.test_case import GtsamTestCase +import gtsam +from gtsam import GaussianBayesNet, GaussianConditional + # some keys _x_ = 11 _y_ = 22 @@ -45,6 +47,18 @@ class TestGaussianBayesNet(GtsamTestCase): np.testing.assert_equal(R, R1) np.testing.assert_equal(d, d1) + def test_evaluate(self): + """Test evaluate method""" + bayesNet = smallBayesNet() + values = gtsam.VectorValues() + values.insert(_x_, np.array([9.0])) + values.insert(_y_, np.array([5.0])) + for i in [0, 1]: + self.assertAlmostEqual(bayesNet.at(i).logProbability(values), + math.log(bayesNet.at(i).evaluate(values))) + self.assertAlmostEqual(bayesNet.logProbability(values), + math.log(bayesNet.evaluate(values))) + def test_sample(self): """Test sample method""" bayesNet = smallBayesNet() diff --git a/python/gtsam/tests/test_HybridBayesNet.py b/python/gtsam/tests/test_HybridBayesNet.py index 75a2e9f8b..c949551c4 100644 --- a/python/gtsam/tests/test_HybridBayesNet.py +++ b/python/gtsam/tests/test_HybridBayesNet.py @@ -10,14 +10,15 @@ Author: Frank Dellaert """ # pylint: disable=invalid-name, no-name-in-module, no-member +import math import unittest import numpy as np from gtsam.symbol_shorthand import A, X from gtsam.utils.test_case import GtsamTestCase -from gtsam import (DiscreteKeys, GaussianMixture, DiscreteConditional, GaussianConditional, GaussianMixture, - HybridBayesNet, HybridValues, noiseModel) +from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, + GaussianMixture, HybridBayesNet, HybridValues, noiseModel) class TestHybridBayesNet(GtsamTestCase): @@ -30,8 +31,8 @@ class TestHybridBayesNet(GtsamTestCase): # Create the continuous conditional I_1x1 = np.eye(1) - gc = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4], - 5.0) + conditional = GaussianConditional.FromMeanAndStddev(X(0), 2 * I_1x1, X(1), [-4], + 5.0) # Create the noise models model0 = noiseModel.Diagonal.Sigmas([2.0]) @@ -45,7 +46,7 @@ class TestHybridBayesNet(GtsamTestCase): # Create hybrid Bayes net. bayesNet = HybridBayesNet() - bayesNet.push_back(gc) + bayesNet.push_back(conditional) bayesNet.push_back(GaussianMixture( [X(1)], [], discrete_keys, [conditional0, conditional1])) bayesNet.push_back(DiscreteConditional(Asia, "99/1")) @@ -56,13 +57,17 @@ class TestHybridBayesNet(GtsamTestCase): values.insert(X(0), [-6]) values.insert(X(1), [1]) - conditionalProbability = gc.evaluate(values.continuous()) + conditionalProbability = conditional.evaluate(values.continuous()) mixtureProbability = conditional0.evaluate(values.continuous()) self.assertAlmostEqual(conditionalProbability * mixtureProbability * 0.99, bayesNet.evaluate(values), places=5) + # Check logProbability + self.assertAlmostEqual(bayesNet.logProbability(values), + math.log(bayesNet.evaluate(values))) + if __name__ == "__main__": unittest.main()