diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 314d4fc63..c0815b2d7 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -210,13 +210,14 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { /* *******************************************************************************/ AlgebraicDecisionTree GaussianMixture::error( const VectorValues &continuousValues) const { - // functor to convert from GaussianConditional to double error value. + // functor to calculate to double error value from GaussianConditional. auto errorFunc = [continuousValues](const GaussianConditional::shared_ptr &conditional) { if (conditional) { return conditional->error(continuousValues); } else { - // return arbitrarily large error + // Return arbitrarily large error if conditional is null + // Conditional is null if it is pruned out. return 1e50; } }; @@ -227,6 +228,7 @@ AlgebraicDecisionTree GaussianMixture::error( /* *******************************************************************************/ double GaussianMixture::error(const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { + // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(discreteValues); return conditional->error(continuousValues); } diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index f070fe07a..fd437f52c 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -112,6 +112,7 @@ AlgebraicDecisionTree GaussianMixtureFactor::error( double GaussianMixtureFactor::error( const VectorValues &continuousValues, const DiscreteValues &discreteValues) const { + // Directly index to get the conditional, no need to build the whole tree. auto factor = factors_(discreteValues); return factor->error(continuousValues); } diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f0d53c416..48c4b6d50 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -244,13 +244,16 @@ AlgebraicDecisionTree HybridBayesNet::error( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree; + // Iterate over each factor. for (size_t idx = 0; idx < size(); idx++) { AlgebraicDecisionTree conditional_error; + if (factors_.at(idx)->isHybrid()) { - // If factor is hybrid, select based on assignment. + // If factor is hybrid, select based on assignment and compute error. GaussianMixture::shared_ptr gm = this->atMixture(idx); conditional_error = gm->error(continuousValues); + // Assign for the first index, add error for subsequent ones. if (idx == 0) { error_tree = conditional_error; } else { @@ -261,6 +264,7 @@ AlgebraicDecisionTree HybridBayesNet::error( // If continuous only, get the (double) error // and add it to the error_tree double error = this->atGaussian(idx)->error(continuousValues); + // Add the computed error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); @@ -273,6 +277,7 @@ AlgebraicDecisionTree HybridBayesNet::error( return error_tree; } +/* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::probPrime( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree = this->error(continuousValues); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index d6937957f..32653bdec 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -428,6 +428,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( const VectorValues &continuousValues) const { AlgebraicDecisionTree error_tree(0.0); + // Iterate over each factor. for (size_t idx = 0; idx < size(); idx++) { AlgebraicDecisionTree factor_error; @@ -435,8 +436,10 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( // If factor is hybrid, select based on assignment. GaussianMixtureFactor::shared_ptr gaussianMixture = boost::static_pointer_cast(factors_.at(idx)); + // Compute factor error. factor_error = gaussianMixture->error(continuousValues); + // If first factor, assign error, else add it. if (idx == 0) { error_tree = factor_error; } else { @@ -450,7 +453,9 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( boost::static_pointer_cast(factors_.at(idx)); GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner(); + // Compute the error of the gaussian factor. double error = gaussian->error(continuousValues); + // Add the gaussian factor error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index 58a915d57..f29a84022 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -86,11 +87,11 @@ class MixtureFactor : public HybridFactor { * elements based on the number of discrete keys and the cardinality of the * keys, so that the decision tree is constructed appropriately. * - * @tparam FACTOR The type of the factor shared pointers being passed in. Will - * be typecast to NonlinearFactor shared pointers. + * @tparam FACTOR The type of the factor shared pointers being passed in. + * Will be typecast to NonlinearFactor shared pointers. * @param keys Vector of keys for continuous factors. * @param discreteKeys Vector of discrete keys. - * @param factors Vector of shared pointers to factors. + * @param factors Vector of nonlinear factors. * @param normalized Flag indicating if the factor error is already * normalized. */ diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 90c76593e..899c129e0 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -196,8 +196,10 @@ class HybridNonlinearFactorGraph { #include class MixtureFactor : gtsam::HybridFactor { - MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, - const gtsam::DecisionTree& factors, bool normalized = false); + MixtureFactor( + const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, + const gtsam::DecisionTree& factors, + bool normalized = false); template MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys, diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 556a5f16a..310081f02 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -104,7 +104,7 @@ TEST(GaussianMixture, Error) { X(2), S2, model); // Create decision tree - DiscreteKey m1(1, 2); + DiscreteKey m1(M(1), 2); GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); @@ -115,12 +115,19 @@ TEST(GaussianMixture, Error) { values.insert(X(2), Vector2::Zero()); auto error_tree = mixture.error(values); + // regression std::vector discrete_keys = {m1}; std::vector leaves = {0.5, 4.3252595}; AlgebraicDecisionTree expected_error(discrete_keys, leaves); - // regression EXPECT(assert_equal(expected_error, error_tree, 1e-6)); + + // Regression for non-tree version. + DiscreteValues assignment; + assignment[M(1)] = 0; + EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8); + assignment[M(1)] = 1; + EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), 1e-8); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 14e1b8dad..ba0622ff9 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -178,6 +178,7 @@ TEST(GaussianMixtureFactor, Error) { AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); std::vector discrete_keys = {m1}; + // Error values for regression test std::vector errors = {1, 4}; AlgebraicDecisionTree expected_error(discrete_keys, errors); diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 8b8ca976b..3593e1952 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -216,8 +216,7 @@ TEST(HybridBayesNet, Error) { // Verify error computation and check for specific error value DiscreteValues discrete_values; - discrete_values[M(0)] = 1; - discrete_values[M(1)] = 1; + insert(discrete_values)(M(0), 1)(M(1), 1); double total_error = 0; for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 5167f6ff6..fe3212eda 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -41,7 +41,8 @@ TEST(MixtureFactor, Constructor) { CHECK(it == factor.end()); } - +/* ************************************************************************* */ +// Test .print() output. TEST(MixtureFactor, Printing) { DiscreteKey m1(1, 2); double between0 = 0.0;