diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8e6e95c17..0e5a34359 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -505,22 +505,22 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( const VectorValues &continuousValues) const { - AlgebraicDecisionTree error_tree(0.0); + AlgebraicDecisionTree result(0.0); // Iterate over each factor. for (auto &factor : factors_) { if (auto f = std::dynamic_pointer_cast(factor)) { // Check for HybridFactor, and call errorTree - error_tree = error_tree + f->errorTree(continuousValues); + result = result + f->errorTree(continuousValues); } else if (auto f = std::dynamic_pointer_cast(factor)) { // Skip discrete factors continue; } else { // Everything else is a continuous only factor HybridValues hv(continuousValues, DiscreteValues()); - error_tree = error_tree + AlgebraicDecisionTree(factor->error(hv)); + result = result + AlgebraicDecisionTree(factor->error(hv)); } } - return error_tree; + return result; } /* ************************************************************************ */ @@ -531,21 +531,14 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { } /* ************************************************************************ */ -DecisionTreeFactor HybridGaussianFactorGraph::probPrime( +AlgebraicDecisionTree HybridGaussianFactorGraph::discretePosterior( const VectorValues &continuousValues) const { - AlgebraicDecisionTree error_tree = this->errorTree(continuousValues); - AlgebraicDecisionTree prob_tree = error_tree.apply([](double error) { + AlgebraicDecisionTree errors = this->errorTree(continuousValues); + AlgebraicDecisionTree p = errors.apply([](double error) { // NOTE: The 0.5 term is handled by each factor return exp(-error); }); - return {GetDiscreteKeys(*this), prob_tree}; -} - -/* ************************************************************************ */ -DiscreteConditional HybridGaussianFactorGraph::discretePosterior( - const VectorValues &continuousValues) const { - auto p = probPrime(continuousValues); - return {p.size(), p}; + return p / p.sum(); } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 5d19b4f83..3ef6218be 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -196,24 +196,17 @@ class GTSAM_EXPORT HybridGaussianFactorGraph */ double probPrime(const HybridValues& values) const; - /** - * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ - * for each discrete assignment, and return as a tree. - * - * @param continuousValues Continuous values at which to compute probability. - * @return DecisionTreeFactor - */ - DecisionTreeFactor probPrime(const VectorValues& continuousValues) const; - /** * @brief Computer posterior P(M|X=x) when all continuous values X are given. - * This is very efficient as this simply probPrime normalized into a - * conditional. + * This is efficient as this simply probPrime normalized. + * + * @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys, + * which we would need, are hard to recover. * * @param continuousValues Continuous values x to condition on. * @return DecisionTreeFactor */ - DiscreteConditional discretePosterior( + AlgebraicDecisionTree discretePosterior( const VectorValues& continuousValues) const; /** diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 8ba1eb762..0c5f52e61 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -619,16 +619,9 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { const auto error_tree = graph.errorTree(delta.continuous()); EXPECT(assert_equal(expectedErrors, error_tree, 1e-7)); - // regression test for probPrime - const DecisionTreeFactor expectedFactor( - s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064}); - auto probabilities = graph.probPrime(delta.continuous()); - EXPECT(assert_equal(expectedFactor, probabilities, 1e-7)); - // regression test for discretePosterior - const DecisionTreeFactor normalized( + const AlgebraicDecisionTree expectedPosterior( s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852}); - DiscreteConditional expectedPosterior(2, normalized); auto posterior = graph.discretePosterior(delta.continuous()); EXPECT(assert_equal(expectedPosterior, posterior, 1e-7)); }