From a898ad3661e14c76d338476bfe46dfa94c7de8ed Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 29 Sep 2024 11:54:54 -0700 Subject: [PATCH] discretePosterior --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 25 ++++++++++---- gtsam/hybrid/HybridGaussianFactorGraph.h | 32 +++++++++++------- .../tests/testHybridGaussianFactorGraph.cpp | 33 +++++++++++-------- 3 files changed, 59 insertions(+), 31 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8a2a7fd15..8e6e95c17 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -42,7 +43,6 @@ #include #include #include -#include #include #include #include @@ -342,14 +342,20 @@ static std::shared_ptr createHybridGaussianFactor( return std::make_shared(discreteSeparator, newFactors); } +/* *******************************************************************************/ +/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys. +static auto GetDiscreteKeys = + [](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys { + const std::set discreteKeySet = hfg.discreteKeys(); + return {discreteKeySet.begin(), discreteKeySet.end()}; +}; + /* *******************************************************************************/ std::pair> HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { // Since we eliminate all continuous variables first, // the discrete separator will be *all* the discrete keys. - const std::set keysForDiscreteVariables = discreteKeys(); - DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(), - keysForDiscreteVariables.end()); + DiscreteKeys discreteSeparator = GetDiscreteKeys(*this); // Collect all the factors to create a set of Gaussian factor graphs in a // decision tree indexed by all discrete keys involved. @@ -525,14 +531,21 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { } /* ************************************************************************ */ -AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( +DecisionTreeFactor HybridGaussianFactorGraph::probPrime( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree = this->errorTree(continuousValues); AlgebraicDecisionTree prob_tree = error_tree.apply([](double error) { // NOTE: The 0.5 term is handled by each factor return exp(-error); }); - return prob_tree; + return {GetDiscreteKeys(*this), prob_tree}; +} + +/* ************************************************************************ */ +DiscreteConditional HybridGaussianFactorGraph::discretePosterior( + const VectorValues &continuousValues) const { + auto p = probPrime(continuousValues); + return {p.size(), p}; } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 7e3aac663..5d19b4f83 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -187,17 +188,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph AlgebraicDecisionTree errorTree( const VectorValues& continuousValues) const; - /** - * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ - * for each discrete assignment, and return as a tree. - * - * @param continuousValues Continuous values at which to compute the - * probability. - * @return AlgebraicDecisionTree - */ - AlgebraicDecisionTree probPrime( - const VectorValues& continuousValues) const; - /** * @brief Compute the unnormalized posterior probability for a continuous * vector values given a specific assignment. @@ -206,6 +196,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph */ double probPrime(const HybridValues& values) const; + /** + * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$ + * for each discrete assignment, and return as a tree. + * + * @param continuousValues Continuous values at which to compute probability. + * @return DecisionTreeFactor + */ + DecisionTreeFactor probPrime(const VectorValues& continuousValues) const; + + /** + * @brief Computer posterior P(M|X=x) when all continuous values X are given. + * This is very efficient as this simply probPrime normalized into a + * conditional. + * + * @param continuousValues Continuous values x to condition on. + * @return DecisionTreeFactor + */ + DiscreteConditional discretePosterior( + const VectorValues& continuousValues) const; + /** * @brief Create a decision tree of factor graphs out of this hybrid factor * graph. diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 6aef60386..8ba1eb762 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -603,29 +603,34 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { /* ****************************************************************************/ // Test hybrid gaussian factor graph error and unnormalized probabilities TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { + // Create switching network with three continuous variables and two discrete: + // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1) Switching s(3); - HybridGaussianFactorGraph graph = s.linearizedFactorGraph; + const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph; - HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(); + const HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(); - HybridValues delta = hybridBayesNet->optimize(); - auto error_tree = graph.errorTree(delta.continuous()); + const HybridValues delta = hybridBayesNet->optimize(); - std::vector discrete_keys = {{M(0), 2}, {M(1), 2}}; + // regression test for errorTree std::vector leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568}; - AlgebraicDecisionTree expected_error(discrete_keys, leaves); - - // regression - EXPECT(assert_equal(expected_error, error_tree, 1e-7)); + AlgebraicDecisionTree expectedErrors(s.modes, leaves); + const auto error_tree = graph.errorTree(delta.continuous()); + EXPECT(assert_equal(expectedErrors, error_tree, 1e-7)); + // regression test for probPrime + const DecisionTreeFactor expectedFactor( + s.modes, std::vector{0.36793249, 0.61247742, 0.59489556, 0.99029064}); auto probabilities = graph.probPrime(delta.continuous()); - std::vector prob_leaves = {0.36793249, 0.61247742, 0.59489556, - 0.99029064}; - AlgebraicDecisionTree expected_probabilities(discrete_keys, prob_leaves); + EXPECT(assert_equal(expectedFactor, probabilities, 1e-7)); - // regression - EXPECT(assert_equal(expected_probabilities, probabilities, 1e-7)); + // regression test for discretePosterior + const DecisionTreeFactor normalized( + s.modes, std::vector{0.14341014, 0.23872714, 0.23187421, 0.38598852}); + DiscreteConditional expectedPosterior(2, normalized); + auto posterior = graph.discretePosterior(delta.continuous()); + EXPECT(assert_equal(expectedPosterior, posterior, 1e-7)); } /* ****************************************************************************/