diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 425b92ff3..983817f03 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -511,6 +511,14 @@ double HybridGaussianFactorGraph::error( return error; } +/* ************************************************************************ */ +double HybridGaussianFactorGraph::probPrime( + const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const { + double error = this->error(continuousValues, discreteValues); + return std::exp(-error); +} + /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( const VectorValues &continuousValues) const { diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index e2c8863ea..88728b6bb 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -204,6 +204,18 @@ class GTSAM_EXPORT HybridGaussianFactorGraph AlgebraicDecisionTree probPrime( const VectorValues& continuousValues) const; + /** + * @brief Compute the unnormalized posterior probability for a continuous + * vector values given a specific assignment. + * + * @param continuousValues The vector values for which to compute the + * posterior probability. + * @param discreteValues The specific assignment to use for the computation. + * @return double + */ + double probPrime(const VectorValues& continuousValues, + const DiscreteValues& discreteValues) const; + /** * @brief Compute the VectorValues solution for the continuous variables for * each mode. diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 98d6dc870..b56b6b62a 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -581,7 +581,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { EXPECT(assert_equal(expected_error, error, 1e-9)); double probs = exp(-error); - double expected_probs = exp(-expected_error); + double expected_probs = graph.probPrime(delta.continuous(), delta.discrete()); // regression EXPECT(assert_equal(expected_probs, probs, 1e-7));