From 9365a02bdb8cbdbab5aa978f0c52d736df5a4de9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 1 Nov 2022 14:01:20 -0400 Subject: [PATCH] add specific assignment error for GaussianMixtureFactor --- gtsam/hybrid/GaussianMixtureFactor.cpp | 8 ++++++++ gtsam/hybrid/GaussianMixtureFactor.h | 12 ++++++++++++ gtsam/hybrid/tests/testGaussianMixtureFactor.cpp | 6 ++++++ 3 files changed, 26 insertions(+) diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index a8500911a..16802516e 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -107,4 +107,12 @@ AlgebraicDecisionTree GaussianMixtureFactor::error( return errorTree; } +/* *******************************************************************************/ +double GaussianMixtureFactor::error( + const VectorValues &continuousVals, + const DiscreteValues &discreteValues) const { + auto factor = factors_(discreteValues); + return factor->error(continuousVals); +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 31ec3c1a0..b6552c078 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -137,6 +138,17 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { */ AlgebraicDecisionTree error(const VectorValues &continuousVals) const; + /** + * @brief Compute the error of this Gaussian Mixture given the continuous + * values and a discrete assignment. + * + * @param continuousVals The continuous values at which to compute the error. + * @param discreteValues The discrete assignment for a specific mode sequence. + * @return double + */ + double error(const VectorValues &continuousVals, + const DiscreteValues &discreteValues) const; + /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { sum = factor.add(sum); diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index e6248f5c9..5c25a0931 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -182,6 +182,12 @@ TEST(GaussianMixtureFactor, Error) { AlgebraicDecisionTree expected_error(discrete_keys, errors); EXPECT(assert_equal(expected_error, error_tree)); + + // Test for single leaf given discrete assignment P(X|M,Z). + DiscreteValues discreteVals; + discreteVals[m1.first] = 1; + EXPECT_DOUBLES_EQUAL(4.0, mixtureFactor.error(continuousVals, discreteVals), + 1e-9); } /* ************************************************************************* */