Add error and probPrime variants
parent
2d688a1986
commit
a4659f01c7
|
@ -468,12 +468,51 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
return error_tree;
|
return error_tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridGaussianFactorGraph::error(
|
||||||
|
const VectorValues &continuousValues,
|
||||||
|
const DiscreteValues &discreteValues) const {
|
||||||
|
double error = 0.0;
|
||||||
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
|
auto factor = factors_.at(idx);
|
||||||
|
|
||||||
|
if (factor->isHybrid()) {
|
||||||
|
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
|
error += c->asMixture()->error(continuousValues, discreteValues);
|
||||||
|
}
|
||||||
|
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
||||||
|
error += f->error(continuousValues, discreteValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (factor->isContinuous()) {
|
||||||
|
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||||
|
error += f->inner()->error(continuousValues);
|
||||||
|
}
|
||||||
|
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
|
error += cg->asGaussian()->error(continuousValues);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridGaussianFactorGraph::probPrime(
|
||||||
|
const VectorValues &continuousValues,
|
||||||
|
const DiscreteValues &discreteValues) const {
|
||||||
|
double error = this->error(continuousValues, discreteValues);
|
||||||
|
// NOTE: The 0.5 term is handled by each factor
|
||||||
|
return std::exp(-error);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
||||||
AlgebraicDecisionTree<Key> prob_tree =
|
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
|
||||||
error_tree.apply([](double error) { return exp(-error); });
|
// NOTE: The 0.5 term is handled by each factor
|
||||||
|
return exp(-error);
|
||||||
|
});
|
||||||
return prob_tree;
|
return prob_tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -182,6 +182,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
|
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute error given a continuous vector values
|
||||||
|
* and a discrete assignment.
|
||||||
|
*
|
||||||
|
* @param continuousValues The continuous VectorValues
|
||||||
|
* for computing the error.
|
||||||
|
* @param discreteValues The specific discrete assignment
|
||||||
|
* whose error we wish to compute.
|
||||||
|
* @return double
|
||||||
|
*/
|
||||||
|
double error(const VectorValues& continuousValues,
|
||||||
|
const DiscreteValues& discreteValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||||
* for each discrete assignment, and return as a tree.
|
* for each discrete assignment, and return as a tree.
|
||||||
|
@ -193,6 +206,18 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
AlgebraicDecisionTree<Key> probPrime(
|
AlgebraicDecisionTree<Key> probPrime(
|
||||||
const VectorValues& continuousValues) const;
|
const VectorValues& continuousValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the unnormalized posterior probability for a continuous
|
||||||
|
* vector values given a specific assignment.
|
||||||
|
*
|
||||||
|
* @param continuousValues The vector values for which to compute the
|
||||||
|
* posterior probability.
|
||||||
|
* @param discreteValues The specific assignment to use for the computation.
|
||||||
|
* @return double
|
||||||
|
*/
|
||||||
|
double probPrime(const VectorValues& continuousValues,
|
||||||
|
const DiscreteValues& discreteValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||||
* eliminated after the continuous keys.
|
* eliminated after the continuous keys.
|
||||||
|
|
|
@ -180,6 +180,12 @@ class HybridGaussianFactorGraph {
|
||||||
void print(string s = "") const;
|
void print(string s = "") const;
|
||||||
bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const;
|
bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
// evaluation
|
||||||
|
double error(const gtsam::VectorValues& continuousValues,
|
||||||
|
const gtsam::DiscreteValues& discreteValues) const;
|
||||||
|
double probPrime(const gtsam::VectorValues& continuousValues,
|
||||||
|
const gtsam::DiscreteValues& discreteValues) const;
|
||||||
|
|
||||||
gtsam::HybridBayesNet* eliminateSequential();
|
gtsam::HybridBayesNet* eliminateSequential();
|
||||||
gtsam::HybridBayesNet* eliminateSequential(
|
gtsam::HybridBayesNet* eliminateSequential(
|
||||||
gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type);
|
||||||
|
|
|
@ -11,6 +11,7 @@ Author: Fan Jiang
|
||||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
import math
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam.symbol_shorthand import C, M, X, Z
|
from gtsam.symbol_shorthand import C, M, X, Z
|
||||||
|
@ -110,7 +111,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
conditional1 = GaussianConditional.FromMeanAndStddev(
|
conditional1 = GaussianConditional.FromMeanAndStddev(
|
||||||
Z(i), I, X(0), [0], sigma=3
|
Z(i), I, X(0), [0], sigma=3
|
||||||
)
|
)
|
||||||
bayesNet.emplaceMixture([Z(i)], [X(0)], keys, [conditional0, conditional1])
|
bayesNet.emplaceMixture(
|
||||||
|
[Z(i)], [X(0)], keys, [conditional0, conditional1]
|
||||||
|
)
|
||||||
|
|
||||||
# Create prior on X(0).
|
# Create prior on X(0).
|
||||||
prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0)
|
prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0)
|
||||||
|
@ -136,7 +139,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
fg.push_back(factor)
|
fg.push_back(factor)
|
||||||
fg.push_back(bayesNet.atGaussian(1))
|
fg.push_back(bayesNet.atGaussian(1))
|
||||||
fg.push_back(bayesNet.atDiscrete(2))
|
fg.push_back(bayesNet.atDiscrete(2))
|
||||||
|
|
||||||
self.assertEqual(fg.size(), 3)
|
self.assertEqual(fg.size(), 3)
|
||||||
|
|
||||||
def test_tiny2(self):
|
def test_tiny2(self):
|
||||||
|
@ -156,8 +159,18 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
fg.push_back(factor)
|
fg.push_back(factor)
|
||||||
fg.push_back(bayesNet.atGaussian(2))
|
fg.push_back(bayesNet.atGaussian(2))
|
||||||
fg.push_back(bayesNet.atDiscrete(3))
|
fg.push_back(bayesNet.atDiscrete(3))
|
||||||
|
|
||||||
self.assertEqual(fg.size(), 4)
|
self.assertEqual(fg.size(), 4)
|
||||||
|
# Calculate ratio between Bayes net probability and the factor graph:
|
||||||
|
continuousValues = gtsam.VectorValues()
|
||||||
|
continuousValues.insert(X(0), sample.at(X(0)))
|
||||||
|
discreteValues = sample.discrete()
|
||||||
|
expected_ratio = bayesNet.evaluate(sample) / fg.probPrime(
|
||||||
|
continuousValues, discreteValues
|
||||||
|
)
|
||||||
|
print(expected_ratio)
|
||||||
|
|
||||||
|
# TODO(dellaert): Change the mode to 0 and calculate the ratio again.
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue