From ca14b7e6ece6bfb5dbb21ce7f9024de6c7567a76 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 1 Nov 2022 20:19:36 -0400 Subject: [PATCH] GaussianMixture error methods --- gtsam/hybrid/GaussianMixture.cpp | 19 ++++++++++ gtsam/hybrid/GaussianMixture.h | 20 +++++++++++ gtsam/hybrid/tests/testGaussianMixture.cpp | 40 ++++++++++++++++++++-- 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 5172a9798..c1194d201 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -208,4 +208,23 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { conditionals_.root_ = pruned_conditionals.root_; } +/* *******************************************************************************/ +AlgebraicDecisionTree GaussianMixture::error( + const VectorValues &continuousVals) const { + // functor to convert from GaussianConditional to double error value. + auto errorFunc = + [continuousVals](const GaussianConditional::shared_ptr &conditional) { + return conditional->error(continuousVals); + }; + DecisionTree errorTree(conditionals_, errorFunc); + return errorTree; +} + +/* *******************************************************************************/ +double GaussianMixture::error(const VectorValues &continuousVals, + const DiscreteValues &discreteValues) const { + auto conditional = conditionals_(discreteValues); + return conditional->error(continuousVals); +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 9792a8532..b3b47fc87 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -143,6 +143,26 @@ class GTSAM_EXPORT GaussianMixture /// Getter for the underlying Conditionals DecisionTree const Conditionals &conditionals(); + /** + * @brief Compute error of the GaussianMixture as a tree. + * + * @param continuousVals The continuous VectorValues. + * @return AlgebraicDecisionTree A decision tree with corresponding keys + * as the factor but leaf values as the error. + */ + AlgebraicDecisionTree error(const VectorValues &continuousVals) const; + + /** + * @brief Compute the error of this Gaussian Mixture given the continuous + * values and a discrete assignment. + * + * @param continuousVals The continuous values at which to compute the error. + * @param discreteValues The discrete assignment for a specific mode sequence. + * @return double + */ + double error(const VectorValues &continuousVals, + const DiscreteValues &discreteValues) const; + /** * @brief Prune the decision tree of Gaussian factors as per the discrete * `decisionTree`. diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 420e22315..556a5f16a 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -78,15 +78,51 @@ TEST(GaussianMixture, Equals) { GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); - GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals); + GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals); // Let's check that this worked: DiscreteValues mode; mode[m1.first] = 1; - auto actual = mixtureFactor(mode); + auto actual = mixture(mode); EXPECT(actual == conditional1); } +/* ************************************************************************* */ +/// Test error method of GaussianMixture. +TEST(GaussianMixture, Error) { + Matrix22 S1 = Matrix22::Identity(); + Matrix22 S2 = Matrix22::Identity() * 2; + Matrix22 R1 = Matrix22::Ones(); + Matrix22 R2 = Matrix22::Ones(); + Vector2 d1(1, 2), d2(2, 1); + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + auto conditional0 = boost::make_shared(X(1), d1, R1, + X(2), S1, model), + conditional1 = boost::make_shared(X(1), d2, R2, + X(2), S2, model); + + // Create decision tree + DiscreteKey m1(1, 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); + GaussianMixture mixture({X(1)}, {X(2)}, {m1}, conditionals); + + VectorValues values; + values.insert(X(1), Vector2::Ones()); + values.insert(X(2), Vector2::Zero()); + auto error_tree = mixture.error(values); + + std::vector discrete_keys = {m1}; + std::vector leaves = {0.5, 4.3252595}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-6)); +} + /* ************************************************************************* */ int main() { TestResult tr;