Make sure all conditional methods can be called in wrappers and satisfy invariants there, as well.
parent
bead5ce4da
commit
51c46410dc
|
@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
|||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||
DiscreteConditional();
|
||||
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
|
||||
|
@ -95,9 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
|
||||
const gtsam::DecisionTreeFactor& marginal,
|
||||
const gtsam::Ordering& orderedKeys);
|
||||
|
||||
// Standard interface
|
||||
double logNormalizationConstant() const;
|
||||
double logProbability(const gtsam::DiscreteValues& values) const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
double error(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteConditional operator*(
|
||||
const gtsam::DiscreteConditional& other) const;
|
||||
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
|
||||
|
@ -119,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
size_t sample(size_t value) const;
|
||||
size_t sample() const;
|
||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
|
||||
// Markdown and HTML
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
|
@ -127,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
gtsam::DefaultKeyFormatter) const;
|
||||
string html(const gtsam::KeyFormatter& keyFormatter,
|
||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
|
||||
// Expose HybridValues versions
|
||||
double logProbability(const gtsam::HybridValues& x) const;
|
||||
double evaluate(const gtsam::HybridValues& x) const;
|
||||
double error(const gtsam::HybridValues& x) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
|
|
|
@ -61,6 +61,7 @@ virtual class HybridConditional {
|
|||
size_t nrParents() const;
|
||||
|
||||
// Standard interface:
|
||||
double logNormalizationConstant() const;
|
||||
double logProbability(const gtsam::HybridValues& values) const;
|
||||
double evaluate(const gtsam::HybridValues& values) const;
|
||||
double operator()(const gtsam::HybridValues& values) const;
|
||||
|
|
|
@ -456,6 +456,7 @@ class GaussianFactorGraph {
|
|||
};
|
||||
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
virtual class GaussianConditional : gtsam::JacobianFactor {
|
||||
// Constructors
|
||||
GaussianConditional(size_t key, Vector d, Matrix R,
|
||||
|
@ -497,6 +498,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
|||
bool equals(const gtsam::GaussianConditional& cg, double tol) const;
|
||||
|
||||
// Standard Interface
|
||||
double logNormalizationConstant() const;
|
||||
double logProbability(const gtsam::VectorValues& x) const;
|
||||
double evaluate(const gtsam::VectorValues& x) const;
|
||||
double error(const gtsam::VectorValues& x) const;
|
||||
|
@ -518,6 +520,11 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
|||
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
|
||||
// Expose HybridValues versions
|
||||
double logProbability(const gtsam::HybridValues& x) const;
|
||||
double evaluate(const gtsam::HybridValues& x) const;
|
||||
double error(const gtsam::HybridValues& x) const;
|
||||
};
|
||||
|
||||
#include <gtsam/linear/GaussianDensity.h>
|
||||
|
|
|
@ -17,8 +17,8 @@ import numpy as np
|
|||
from gtsam.symbol_shorthand import A, X
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
||||
GaussianMixture, HybridBayesNet, HybridValues, noiseModel)
|
||||
from gtsam import (DiscreteConditional, DiscreteKeys, DiscreteValues, GaussianConditional,
|
||||
GaussianMixture, HybridBayesNet, HybridValues, noiseModel, VectorValues)
|
||||
|
||||
|
||||
class TestHybridBayesNet(GtsamTestCase):
|
||||
|
@ -53,9 +53,13 @@ class TestHybridBayesNet(GtsamTestCase):
|
|||
|
||||
# Create values at which to evaluate.
|
||||
values = HybridValues()
|
||||
values.insert(asiaKey, 0)
|
||||
values.insert(X(0), [-6])
|
||||
values.insert(X(1), [1])
|
||||
continuous = VectorValues()
|
||||
continuous.insert(X(0), [-6])
|
||||
continuous.insert(X(1), [1])
|
||||
values.insert(continuous)
|
||||
discrete = DiscreteValues()
|
||||
discrete[asiaKey] = 0
|
||||
values.insert(discrete)
|
||||
|
||||
conditionalProbability = conditional.evaluate(values.continuous())
|
||||
mixtureProbability = conditional0.evaluate(values.continuous())
|
||||
|
@ -68,6 +72,26 @@ class TestHybridBayesNet(GtsamTestCase):
|
|||
self.assertAlmostEqual(bayesNet.logProbability(values),
|
||||
math.log(bayesNet.evaluate(values)))
|
||||
|
||||
# Check invariance for all conditionals:
|
||||
self.check_invariance(bayesNet.at(0).asGaussian(), continuous)
|
||||
self.check_invariance(bayesNet.at(0).asGaussian(), values)
|
||||
self.check_invariance(bayesNet.at(0), values)
|
||||
|
||||
self.check_invariance(bayesNet.at(1), values)
|
||||
|
||||
self.check_invariance(bayesNet.at(2).asDiscrete(), discrete)
|
||||
self.check_invariance(bayesNet.at(2).asDiscrete(), values)
|
||||
self.check_invariance(bayesNet.at(2), values)
|
||||
|
||||
def check_invariance(self, conditional, values):
|
||||
"""Check invariance for given conditional."""
|
||||
probability = conditional.evaluate(values)
|
||||
self.assertTrue(probability >= 0.0)
|
||||
logProb = conditional.logProbability(values)
|
||||
self.assertAlmostEqual(probability, np.exp(logProb))
|
||||
expected = conditional.logNormalizationConstant() - conditional.error(values)
|
||||
self.assertAlmostEqual(logProb, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue