Make sure all conditional methods can be called in wrappers and satisfy invariants there, as well.

release/4.3a0
Frank Dellaert 2023-01-14 13:24:54 -08:00
parent bead5ce4da
commit 51c46410dc
4 changed files with 49 additions and 6 deletions

View File

@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
};
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/HybridValues.h>
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional();
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
@ -95,9 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);
// Standard interface
double logNormalizationConstant() const;
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
double error(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const;
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
@ -119,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
size_t sample(size_t value) const;
size_t sample() const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
// Markdown and HTML
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
@ -127,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DefaultKeyFormatter) const;
string html(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;
// Expose HybridValues versions
double logProbability(const gtsam::HybridValues& x) const;
double evaluate(const gtsam::HybridValues& x) const;
double error(const gtsam::HybridValues& x) const;
};
#include <gtsam/discrete/DiscreteDistribution.h>

View File

@ -61,6 +61,7 @@ virtual class HybridConditional {
size_t nrParents() const;
// Standard interface:
double logNormalizationConstant() const;
double logProbability(const gtsam::HybridValues& values) const;
double evaluate(const gtsam::HybridValues& values) const;
double operator()(const gtsam::HybridValues& values) const;

View File

@ -456,6 +456,7 @@ class GaussianFactorGraph {
};
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/hybrid/HybridValues.h>
virtual class GaussianConditional : gtsam::JacobianFactor {
// Constructors
GaussianConditional(size_t key, Vector d, Matrix R,
@ -497,6 +498,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
// Standard Interface
double logNormalizationConstant() const;
double logProbability(const gtsam::VectorValues& x) const;
double evaluate(const gtsam::VectorValues& x) const;
double error(const gtsam::VectorValues& x) const;
@ -518,6 +520,11 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
// enabling serialization functionality
void serialize() const;
// Expose HybridValues versions
double logProbability(const gtsam::HybridValues& x) const;
double evaluate(const gtsam::HybridValues& x) const;
double error(const gtsam::HybridValues& x) const;
};
#include <gtsam/linear/GaussianDensity.h>

View File

@ -17,8 +17,8 @@ import numpy as np
from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
GaussianMixture, HybridBayesNet, HybridValues, noiseModel)
from gtsam import (DiscreteConditional, DiscreteKeys, DiscreteValues, GaussianConditional,
GaussianMixture, HybridBayesNet, HybridValues, noiseModel, VectorValues)
class TestHybridBayesNet(GtsamTestCase):
@ -53,9 +53,13 @@ class TestHybridBayesNet(GtsamTestCase):
# Create values at which to evaluate.
values = HybridValues()
values.insert(asiaKey, 0)
values.insert(X(0), [-6])
values.insert(X(1), [1])
continuous = VectorValues()
continuous.insert(X(0), [-6])
continuous.insert(X(1), [1])
values.insert(continuous)
discrete = DiscreteValues()
discrete[asiaKey] = 0
values.insert(discrete)
conditionalProbability = conditional.evaluate(values.continuous())
mixtureProbability = conditional0.evaluate(values.continuous())
@ -68,6 +72,26 @@ class TestHybridBayesNet(GtsamTestCase):
self.assertAlmostEqual(bayesNet.logProbability(values),
math.log(bayesNet.evaluate(values)))
# Check invariance for all conditionals:
self.check_invariance(bayesNet.at(0).asGaussian(), continuous)
self.check_invariance(bayesNet.at(0).asGaussian(), values)
self.check_invariance(bayesNet.at(0), values)
self.check_invariance(bayesNet.at(1), values)
self.check_invariance(bayesNet.at(2).asDiscrete(), discrete)
self.check_invariance(bayesNet.at(2).asDiscrete(), values)
self.check_invariance(bayesNet.at(2), values)
def check_invariance(self, conditional, values):
"""Check invariance for given conditional."""
probability = conditional.evaluate(values)
self.assertTrue(probability >= 0.0)
logProb = conditional.logProbability(values)
self.assertAlmostEqual(probability, np.exp(logProb))
expected = conditional.logNormalizationConstant() - conditional.error(values)
self.assertAlmostEqual(logProb, expected)
if __name__ == "__main__":
unittest.main()