diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index c1194d201..2c5aabf55 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -214,7 +214,12 @@ AlgebraicDecisionTree GaussianMixture::error( // functor to convert from GaussianConditional to double error value. auto errorFunc = [continuousVals](const GaussianConditional::shared_ptr &conditional) { - return conditional->error(continuousVals); + if (conditional) { + return conditional->error(continuousVals); + } else { + // return arbitrarily large error + return 1e50; + } }; DecisionTree errorTree(conditionals_, errorFunc); return errorTree; diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index cc27600f0..91d92ab0e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -145,4 +145,45 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { return gbn.optimize(); } +/* ************************************************************************* */ +double HybridBayesNet::error(const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const { + GaussianBayesNet gbn = this->choose(discreteValues); + return gbn.error(continuousValues); +} + +/* ************************************************************************* */ +AlgebraicDecisionTree HybridBayesNet::error( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree error_tree; + + for (size_t idx = 0; idx < size(); idx++) { + AlgebraicDecisionTree conditional_error; + if (factors_.at(idx)->isHybrid()) { + // If factor is hybrid, select based on assignment. + GaussianMixture::shared_ptr gm = this->atMixture(idx); + conditional_error = gm->error(continuousValues); + + if (idx == 0) { + error_tree = conditional_error; + } else { + error_tree = error_tree + conditional_error; + } + + } else if (factors_.at(idx)->isContinuous()) { + // If continuous only, get the (double) error + // and add it to the error_tree + double error = this->atGaussian(idx)->error(continuousValues); + error_tree = error_tree.apply( + [error](double leaf_value) { return leaf_value + error; }); + + } else if (factors_.at(idx)->isDiscrete()) { + // If factor at `idx` is discrete-only, we skip. + continue; + } + } + + return error_tree; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index b8234d70a..82e890cc4 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -123,6 +123,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves. HybridBayesNet prune(size_t maxNrLeaves) const; + /** + * @brief 0.5 * sum of squared Mahalanobis distances + * for a specific discrete assignment. + * + * @param continuousValues Continuous values at which to compute the error. + * @param discreteValues Discrete assignment for a specific mode sequence. + * @return double + */ + double error(const VectorValues &continuousValues, + const DiscreteValues &discreteValues) const; + + AlgebraicDecisionTree error(const VectorValues &continuousValues) const; + /// @} private: diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 5885fdcdc..4ca760f88 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -183,6 +183,61 @@ TEST(HybridBayesNet, OptimizeMultifrontal) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } +/* ****************************************************************************/ +// Test bayes net error +TEST(HybridBayesNet, Error) { + Switching s(3); + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + s.linearizedFactorGraph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + auto error_tree = hybridBayesNet->error(delta.continuous()); + + std::vector discrete_keys = {{M(1), 2}, {M(2), 2}}; + std::vector leaves = {0.0097568009, 3.3973404e-31, 0.029126214, + 0.0097568009}; + AlgebraicDecisionTree expected_error(discrete_keys, leaves); + + // regression + EXPECT(assert_equal(expected_error, error_tree, 1e-9)); + + // Error on pruned bayes net + auto prunedBayesNet = hybridBayesNet->prune(2); + auto pruned_error_tree = prunedBayesNet.error(delta.continuous()); + + std::vector pruned_leaves = {2e50, 3.3973404e-31, 2e50, 0.0097568009}; + AlgebraicDecisionTree expected_pruned_error(discrete_keys, + pruned_leaves); + + // regression + EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9)); + + // Verify error computation and check for specific error value + DiscreteValues discrete_values; + discrete_values[M(1)] = 1; + discrete_values[M(2)] = 1; + + double total_error = 0; + for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { + if (hybridBayesNet->at(idx)->isHybrid()) { + double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(), + discrete_values); + total_error += error; + } else if (hybridBayesNet->at(idx)->isContinuous()) { + double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous()); + total_error += error; + } + } + + EXPECT_DOUBLES_EQUAL( + total_error, hybridBayesNet->error(delta.continuous(), discrete_values), + 1e-9); + EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9); + EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9); +} + /* ****************************************************************************/ // Test bayes net pruning TEST(HybridBayesNet, Prune) {