diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 181b1e6a5..a8500911a 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -95,4 +95,16 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() }; return {factors_, wrap}; } + +/* *******************************************************************************/ +AlgebraicDecisionTree GaussianMixtureFactor::error( + const VectorValues &continuousVals) const { + // functor to convert from sharedFactor to double error value. + auto errorFunc = [continuousVals](const GaussianFactor::shared_ptr &factor) { + return factor->error(continuousVals); + }; + DecisionTree errorTree(factors_, errorFunc); + return errorTree; +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index f27c49180..31ec3c1a0 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -20,6 +20,7 @@ #pragma once +#include #include #include #include @@ -131,17 +132,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @brief Compute error of the GaussianMixtureFactor as a tree. * * @param continuousVals The continuous VectorValues. - * @return DecisionTree A decision tree with corresponding keys + * @return AlgebraicDecisionTree A decision tree with corresponding keys * as the factor but leaf values as the error. */ - DecisionTree error(const VectorValues &c) const { - // functor to convert from sharedFactor to double error value. - auto errorFunc = [c](const GaussianFactor::shared_ptr &factor) { - return factor->error(c); - }; - DecisionTree errorTree(factors_, errorFunc); - return errorTree; - } + AlgebraicDecisionTree error(const VectorValues &continuousVals) const; /// Add MixtureFactor to a Sum, syntactic sugar. friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index cb9068c30..e6248f5c9 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file GaussianMixtureFactor.cpp + * @file testGaussianMixtureFactor.cpp * @brief Unit tests for GaussianMixtureFactor * @author Varun Agrawal * @author Fan Jiang @@ -135,7 +135,7 @@ TEST(GaussianMixtureFactor, Printing) { EXPECT(assert_print_equal(expected, mixtureFactor)); } -TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) { +TEST(GaussianMixtureFactor, GaussianMixture) { KeyVector keys; keys.push_back(X(0)); keys.push_back(X(1)); @@ -151,6 +151,39 @@ TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) { EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size()); } +/* ************************************************************************* */ +// Test the error of the GaussianMixtureFactor +TEST(GaussianMixtureFactor, Error) { + DiscreteKey m1(1, 2); + + auto A01 = Matrix2::Identity(); + auto A02 = Matrix2::Identity(); + + auto A11 = Matrix2::Identity(); + auto A12 = Matrix2::Identity() * 2; + + auto b = Vector2::Zero(); + + auto f0 = boost::make_shared(X(1), A01, X(2), A02, b); + auto f1 = boost::make_shared(X(1), A11, X(2), A12, b); + std::vector factors{f0, f1}; + + GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + + VectorValues continuousVals; + continuousVals.insert(X(1), Vector2(0, 0)); + continuousVals.insert(X(2), Vector2(1, 1)); + + // error should return a tree of errors, with nodes for each discrete value. + AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousVals); + + std::vector discrete_keys = {m1}; + std::vector errors = {1, 4}; + AlgebraicDecisionTree expected_error(discrete_keys, errors); + + EXPECT(assert_equal(expected_error, error_tree)); +} + /* ************************************************************************* */ int main() { TestResult tr;