Make sure all conditional methods can be called in wrappers and satisfy invariants there, as well.
parent
bead5ce4da
commit
51c46410dc
|
@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
DiscreteConditional();
|
DiscreteConditional();
|
||||||
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||||
|
@ -95,9 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||||
const gtsam::DecisionTreeFactor& marginal,
|
const gtsam::DecisionTreeFactor& marginal,
|
||||||
const gtsam::Ordering& orderedKeys);
|
const gtsam::Ordering& orderedKeys);
|
||||||
|
|
||||||
|
// Standard interface
|
||||||
|
double logNormalizationConstant() const;
|
||||||
double logProbability(const gtsam::DiscreteValues& values) const;
|
double logProbability(const gtsam::DiscreteValues& values) const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double error(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteConditional operator*(
|
gtsam::DiscreteConditional operator*(
|
||||||
const gtsam::DiscreteConditional& other) const;
|
const gtsam::DiscreteConditional& other) const;
|
||||||
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
|
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
|
||||||
|
@ -119,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
size_t sample(size_t value) const;
|
size_t sample(size_t value) const;
|
||||||
size_t sample() const;
|
size_t sample() const;
|
||||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
|
|
||||||
|
// Markdown and HTML
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
@ -127,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string html(const gtsam::KeyFormatter& keyFormatter,
|
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
|
|
||||||
|
// Expose HybridValues versions
|
||||||
|
double logProbability(const gtsam::HybridValues& x) const;
|
||||||
|
double evaluate(const gtsam::HybridValues& x) const;
|
||||||
|
double error(const gtsam::HybridValues& x) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
|
|
@ -61,6 +61,7 @@ virtual class HybridConditional {
|
||||||
size_t nrParents() const;
|
size_t nrParents() const;
|
||||||
|
|
||||||
// Standard interface:
|
// Standard interface:
|
||||||
|
double logNormalizationConstant() const;
|
||||||
double logProbability(const gtsam::HybridValues& values) const;
|
double logProbability(const gtsam::HybridValues& values) const;
|
||||||
double evaluate(const gtsam::HybridValues& values) const;
|
double evaluate(const gtsam::HybridValues& values) const;
|
||||||
double operator()(const gtsam::HybridValues& values) const;
|
double operator()(const gtsam::HybridValues& values) const;
|
||||||
|
|
|
@ -456,6 +456,7 @@ class GaussianFactorGraph {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
virtual class GaussianConditional : gtsam::JacobianFactor {
|
virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
// Constructors
|
// Constructors
|
||||||
GaussianConditional(size_t key, Vector d, Matrix R,
|
GaussianConditional(size_t key, Vector d, Matrix R,
|
||||||
|
@ -497,6 +498,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
||||||
|
|
||||||
// Standard Interface
|
// Standard Interface
|
||||||
|
double logNormalizationConstant() const;
|
||||||
double logProbability(const gtsam::VectorValues& x) const;
|
double logProbability(const gtsam::VectorValues& x) const;
|
||||||
double evaluate(const gtsam::VectorValues& x) const;
|
double evaluate(const gtsam::VectorValues& x) const;
|
||||||
double error(const gtsam::VectorValues& x) const;
|
double error(const gtsam::VectorValues& x) const;
|
||||||
|
@ -518,6 +520,11 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||||
|
|
||||||
// enabling serialization functionality
|
// enabling serialization functionality
|
||||||
void serialize() const;
|
void serialize() const;
|
||||||
|
|
||||||
|
// Expose HybridValues versions
|
||||||
|
double logProbability(const gtsam::HybridValues& x) const;
|
||||||
|
double evaluate(const gtsam::HybridValues& x) const;
|
||||||
|
double error(const gtsam::HybridValues& x) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianDensity.h>
|
#include <gtsam/linear/GaussianDensity.h>
|
||||||
|
|
|
@ -17,8 +17,8 @@ import numpy as np
|
||||||
from gtsam.symbol_shorthand import A, X
|
from gtsam.symbol_shorthand import A, X
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
from gtsam import (DiscreteConditional, DiscreteKeys, DiscreteValues, GaussianConditional,
|
||||||
GaussianMixture, HybridBayesNet, HybridValues, noiseModel)
|
GaussianMixture, HybridBayesNet, HybridValues, noiseModel, VectorValues)
|
||||||
|
|
||||||
|
|
||||||
class TestHybridBayesNet(GtsamTestCase):
|
class TestHybridBayesNet(GtsamTestCase):
|
||||||
|
@ -53,9 +53,13 @@ class TestHybridBayesNet(GtsamTestCase):
|
||||||
|
|
||||||
# Create values at which to evaluate.
|
# Create values at which to evaluate.
|
||||||
values = HybridValues()
|
values = HybridValues()
|
||||||
values.insert(asiaKey, 0)
|
continuous = VectorValues()
|
||||||
values.insert(X(0), [-6])
|
continuous.insert(X(0), [-6])
|
||||||
values.insert(X(1), [1])
|
continuous.insert(X(1), [1])
|
||||||
|
values.insert(continuous)
|
||||||
|
discrete = DiscreteValues()
|
||||||
|
discrete[asiaKey] = 0
|
||||||
|
values.insert(discrete)
|
||||||
|
|
||||||
conditionalProbability = conditional.evaluate(values.continuous())
|
conditionalProbability = conditional.evaluate(values.continuous())
|
||||||
mixtureProbability = conditional0.evaluate(values.continuous())
|
mixtureProbability = conditional0.evaluate(values.continuous())
|
||||||
|
@ -68,6 +72,26 @@ class TestHybridBayesNet(GtsamTestCase):
|
||||||
self.assertAlmostEqual(bayesNet.logProbability(values),
|
self.assertAlmostEqual(bayesNet.logProbability(values),
|
||||||
math.log(bayesNet.evaluate(values)))
|
math.log(bayesNet.evaluate(values)))
|
||||||
|
|
||||||
|
# Check invariance for all conditionals:
|
||||||
|
self.check_invariance(bayesNet.at(0).asGaussian(), continuous)
|
||||||
|
self.check_invariance(bayesNet.at(0).asGaussian(), values)
|
||||||
|
self.check_invariance(bayesNet.at(0), values)
|
||||||
|
|
||||||
|
self.check_invariance(bayesNet.at(1), values)
|
||||||
|
|
||||||
|
self.check_invariance(bayesNet.at(2).asDiscrete(), discrete)
|
||||||
|
self.check_invariance(bayesNet.at(2).asDiscrete(), values)
|
||||||
|
self.check_invariance(bayesNet.at(2), values)
|
||||||
|
|
||||||
|
def check_invariance(self, conditional, values):
|
||||||
|
"""Check invariance for given conditional."""
|
||||||
|
probability = conditional.evaluate(values)
|
||||||
|
self.assertTrue(probability >= 0.0)
|
||||||
|
logProb = conditional.logProbability(values)
|
||||||
|
self.assertAlmostEqual(probability, np.exp(logProb))
|
||||||
|
expected = conditional.logNormalizationConstant() - conditional.error(values)
|
||||||
|
self.assertAlmostEqual(logProb, expected)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue