Add error and probPrime variants

release/4.3a0
Frank Dellaert 2022-12-29 14:13:35 -05:00
parent 2d688a1986
commit a4659f01c7
4 changed files with 88 additions and 5 deletions

View File

@ -468,12 +468,51 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
return error_tree; return error_tree;
} }
/* ************************************************************************ */
double HybridGaussianFactorGraph::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = 0.0;
for (size_t idx = 0; idx < size(); idx++) {
auto factor = factors_.at(idx);
if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(continuousValues, discreteValues);
}
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(continuousValues, discreteValues);
}
} else if (factor->isContinuous()) {
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(continuousValues);
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(continuousValues);
}
}
}
return error;
}
/* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = this->error(continuousValues, discreteValues);
// NOTE: The 0.5 term is handled by each factor
return std::exp(-error);
}
/* ************************************************************************ */ /* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime( AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues); AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> prob_tree = AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
error_tree.apply([](double error) { return exp(-error); }); // NOTE: The 0.5 term is handled by each factor
return exp(-error);
});
return prob_tree; return prob_tree;
} }

View File

@ -182,6 +182,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/ */
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const; AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
/**
* @brief Compute error given a continuous vector values
* and a discrete assignment.
*
* @param continuousValues The continuous VectorValues
* for computing the error.
* @param discreteValues The specific discrete assignment
* whose error we wish to compute.
* @return double
*/
double error(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;
/** /**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree. * for each discrete assignment, and return as a tree.
@ -193,6 +206,18 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
AlgebraicDecisionTree<Key> probPrime( AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const; const VectorValues& continuousValues) const;
/**
* @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment.
*
* @param continuousValues The vector values for which to compute the
* posterior probability.
* @param discreteValues The specific assignment to use for the computation.
* @return double
*/
double probPrime(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;
/** /**
* @brief Return a Colamd constrained ordering where the discrete keys are * @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys. * eliminated after the continuous keys.

View File

@ -180,6 +180,12 @@ class HybridGaussianFactorGraph {
void print(string s = "") const; void print(string s = "") const;
bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const;
// evaluation
double error(const gtsam::VectorValues& continuousValues,
const gtsam::DiscreteValues& discreteValues) const;
double probPrime(const gtsam::VectorValues& continuousValues,
const gtsam::DiscreteValues& discreteValues) const;
gtsam::HybridBayesNet* eliminateSequential(); gtsam::HybridBayesNet* eliminateSequential();
gtsam::HybridBayesNet* eliminateSequential( gtsam::HybridBayesNet* eliminateSequential(
gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type);

View File

@ -11,6 +11,7 @@ Author: Fan Jiang
# pylint: disable=invalid-name, no-name-in-module, no-member # pylint: disable=invalid-name, no-name-in-module, no-member
import unittest import unittest
import math
import numpy as np import numpy as np
from gtsam.symbol_shorthand import C, M, X, Z from gtsam.symbol_shorthand import C, M, X, Z
@ -110,7 +111,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
conditional1 = GaussianConditional.FromMeanAndStddev( conditional1 = GaussianConditional.FromMeanAndStddev(
Z(i), I, X(0), [0], sigma=3 Z(i), I, X(0), [0], sigma=3
) )
bayesNet.emplaceMixture([Z(i)], [X(0)], keys, [conditional0, conditional1]) bayesNet.emplaceMixture(
[Z(i)], [X(0)], keys, [conditional0, conditional1]
)
# Create prior on X(0). # Create prior on X(0).
prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0) prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0)
@ -136,7 +139,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
fg.push_back(factor) fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(1)) fg.push_back(bayesNet.atGaussian(1))
fg.push_back(bayesNet.atDiscrete(2)) fg.push_back(bayesNet.atDiscrete(2))
self.assertEqual(fg.size(), 3) self.assertEqual(fg.size(), 3)
def test_tiny2(self): def test_tiny2(self):
@ -156,8 +159,18 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
fg.push_back(factor) fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(2)) fg.push_back(bayesNet.atGaussian(2))
fg.push_back(bayesNet.atDiscrete(3)) fg.push_back(bayesNet.atDiscrete(3))
self.assertEqual(fg.size(), 4) self.assertEqual(fg.size(), 4)
# Calculate ratio between Bayes net probability and the factor graph:
continuousValues = gtsam.VectorValues()
continuousValues.insert(X(0), sample.at(X(0)))
discreteValues = sample.discrete()
expected_ratio = bayesNet.evaluate(sample) / fg.probPrime(
continuousValues, discreteValues
)
print(expected_ratio)
# TODO(dellaert): Change the mode to 0 and calculate the ratio again.
if __name__ == "__main__": if __name__ == "__main__":