From cb55af3a81093c05a4b41b7abec52d804050b86d Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 8 Nov 2022 14:20:51 -0500 Subject: [PATCH] separate HybridGaussianFactorGraph::error() using both continuous and discrete values --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 53 +++++++++++-------- gtsam/hybrid/HybridGaussianFactorGraph.h | 13 +++++ .../tests/testHybridGaussianFactorGraph.cpp | 25 +++++++++ 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 62d681665..425b92ff3 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -483,6 +483,34 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( return error_tree; } +/* ************************************************************************ */ +double HybridGaussianFactorGraph::error( + const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const { + double error = 0.0; + for (size_t idx = 0; idx < size(); idx++) { + auto factor = factors_.at(idx); + + if (factor->isHybrid()) { + if (auto c = boost::dynamic_pointer_cast(factor)) { + error += c->asMixture()->error(continuousValues, discreteValues); + } + if (auto f = boost::dynamic_pointer_cast(factor)) { + error += f->error(continuousValues, discreteValues); + } + + } else if (factor->isContinuous()) { + if (auto f = boost::dynamic_pointer_cast(factor)) { + error += f->inner()->error(continuousValues); + } + if (auto cg = boost::dynamic_pointer_cast(factor)) { + error += cg->asGaussian()->error(continuousValues); + } + } + } + return error; +} + /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( const VectorValues &continuousValues) const { @@ -539,32 +567,11 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::continuousProbPrimes( continue; } - double error = 0.0; // Compute the error given the delta and the assignment. - for (size_t idx = 0; idx < size(); idx++) { - auto factor = factors_.at(idx); - - if (factor->isHybrid()) { - if (auto c = boost::dynamic_pointer_cast(factor)) { - error += c->asMixture()->error(delta, assignment); - } - if (auto f = - boost::dynamic_pointer_cast(factor)) { - error += f->error(delta, assignment); - } - - } else if (factor->isContinuous()) { - if (auto f = - boost::dynamic_pointer_cast(factor)) { - error += f->inner()->error(delta); - } - if (auto cg = boost::dynamic_pointer_cast(factor)) { - error += cg->asGaussian()->error(delta); - } - } - } + double error = this->error(delta, assignment); probPrimes.push_back(exp(-error)); } + AlgebraicDecisionTree probPrimeTree(discrete_keys, probPrimes); return probPrimeTree; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 1198cc8bc..e2c8863ea 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -180,6 +180,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph */ AlgebraicDecisionTree error(const VectorValues& continuousValues) const; + /** + * @brief Compute error given a continuous vector values + * and a discrete assignment. + * + * @param continuousValues The continuous VectorValues + * for computing the error. + * @param discreteValues The specific discrete assignment + * whose error we wish to compute. + * @return double + */ + double error(const VectorValues& continuousValues, + const DiscreteValues& discreteValues) const; + /** * @brief Compute unnormalized probability for each discrete assignment, * and return as a tree. diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 7877461b6..98d6dc870 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -569,6 +569,31 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) { HybridGaussianFactorGraph graph = s.linearizedFactorGraph; + Ordering hybridOrdering = graph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + graph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + double error = graph.error(delta.continuous(), delta.discrete()); + + double expected_error = 0.490243199; + // regression + EXPECT(assert_equal(expected_error, error, 1e-9)); + + double probs = exp(-error); + double expected_probs = exp(-expected_error); + + // regression + EXPECT(assert_equal(expected_probs, probs, 1e-7)); +} + +/* ****************************************************************************/ +// Test hybrid gaussian factor graph error and unnormalized probabilities +TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { + Switching s(3); + + HybridGaussianFactorGraph graph = s.linearizedFactorGraph; + Ordering hybridOrdering = graph.getHybridOrdering(); HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential(hybridOrdering);