diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 32653bdec..b110f8586 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -468,12 +468,51 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( return error_tree; } +/* ************************************************************************ */ +double HybridGaussianFactorGraph::error( + const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const { + double error = 0.0; + for (size_t idx = 0; idx < size(); idx++) { + auto factor = factors_.at(idx); + + if (factor->isHybrid()) { + if (auto c = boost::dynamic_pointer_cast(factor)) { + error += c->asMixture()->error(continuousValues, discreteValues); + } + if (auto f = boost::dynamic_pointer_cast(factor)) { + error += f->error(continuousValues, discreteValues); + } + + } else if (factor->isContinuous()) { + if (auto f = boost::dynamic_pointer_cast(factor)) { + error += f->inner()->error(continuousValues); + } + if (auto cg = boost::dynamic_pointer_cast(factor)) { + error += cg->asGaussian()->error(continuousValues); + } + } + } + return error; +} + +/* ************************************************************************ */ +double HybridGaussianFactorGraph::probPrime( + const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const { + double error = this->error(continuousValues, discreteValues); + // NOTE: The 0.5 term is handled by each factor + return std::exp(-error); +} + /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree = this->error(continuousValues); - AlgebraicDecisionTree prob_tree = - error_tree.apply([](double error) { return exp(-error); }); + AlgebraicDecisionTree prob_tree = error_tree.apply([](double error) { + // NOTE: The 0.5 term is handled by each factor + return exp(-error); + }); return prob_tree; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index ac9ae1a46..9de18b6af 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -182,6 +182,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph */ AlgebraicDecisionTree error(const VectorValues& continuousValues) const; + /** + * @brief Compute error given a continuous vector values + * and a discrete assignment. + * + * @param continuousValues The continuous VectorValues + * for computing the error. + * @param discreteValues The specific discrete assignment + * whose error we wish to compute. + * @return double + */ + double error(const VectorValues& continuousValues, + const DiscreteValues& discreteValues) const; + /** * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ * for each discrete assignment, and return as a tree. @@ -193,6 +206,18 @@ class GTSAM_EXPORT HybridGaussianFactorGraph AlgebraicDecisionTree probPrime( const VectorValues& continuousValues) const; + /** + * @brief Compute the unnormalized posterior probability for a continuous + * vector values given a specific assignment. + * + * @param continuousValues The vector values for which to compute the + * posterior probability. + * @param discreteValues The specific assignment to use for the computation. + * @return double + */ + double probPrime(const VectorValues& continuousValues, + const DiscreteValues& discreteValues) const; + /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys. diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 29247cdc3..3dbf5d542 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -180,6 +180,12 @@ class HybridGaussianFactorGraph { void print(string s = "") const; bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; + // evaluation + double error(const gtsam::VectorValues& continuousValues, + const gtsam::DiscreteValues& discreteValues) const; + double probPrime(const gtsam::VectorValues& continuousValues, + const gtsam::DiscreteValues& discreteValues) const; + gtsam::HybridBayesNet* eliminateSequential(); gtsam::HybridBayesNet* eliminateSequential( gtsam::Ordering::OrderingType type); diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 37243b937..2ebc87971 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -11,6 +11,7 @@ Author: Fan Jiang # pylint: disable=invalid-name, no-name-in-module, no-member import unittest +import math import numpy as np from gtsam.symbol_shorthand import C, M, X, Z @@ -110,7 +111,9 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): conditional1 = GaussianConditional.FromMeanAndStddev( Z(i), I, X(0), [0], sigma=3 ) - bayesNet.emplaceMixture([Z(i)], [X(0)], keys, [conditional0, conditional1]) + bayesNet.emplaceMixture( + [Z(i)], [X(0)], keys, [conditional0, conditional1] + ) # Create prior on X(0). prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0) @@ -136,7 +139,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): fg.push_back(factor) fg.push_back(bayesNet.atGaussian(1)) fg.push_back(bayesNet.atDiscrete(2)) - + self.assertEqual(fg.size(), 3) def test_tiny2(self): @@ -156,8 +159,18 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): fg.push_back(factor) fg.push_back(bayesNet.atGaussian(2)) fg.push_back(bayesNet.atDiscrete(3)) - + self.assertEqual(fg.size(), 4) + # Calculate ratio between Bayes net probability and the factor graph: + continuousValues = gtsam.VectorValues() + continuousValues.insert(X(0), sample.at(X(0))) + discreteValues = sample.discrete() + expected_ratio = bayesNet.evaluate(sample) / fg.probPrime( + continuousValues, discreteValues + ) + print(expected_ratio) + + # TODO(dellaert): Change the mode to 0 and calculate the ratio again. if __name__ == "__main__":