diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp
index 3bafe5a9c..31177ddb7 100644
--- a/gtsam/hybrid/HybridBayesNet.cpp
+++ b/gtsam/hybrid/HybridBayesNet.cpp
@@ -281,6 +281,36 @@ HybridValues HybridBayesNet::sample() const {
   return sample(&kRandomNumberGenerator);
 }
 
+/* ************************************************************************* */
+AlgebraicDecisionTree<Key> HybridBayesNet::error(
+    const VectorValues &continuousValues) const {
+  AlgebraicDecisionTree<Key> result(0.0);
+
+  // Iterate over each conditional.
+  for (auto &&conditional : *this) {
+    if (auto gm = conditional->asMixture()) {
+      // If conditional is hybrid, compute error for all assignments.
+      result = result + gm->error(continuousValues);
+
+    } else if (auto gc = conditional->asGaussian()) {
+      // If continuous, get the error and add it to the result
+      double error = gc->error(continuousValues);
+      // Add the computed error to every leaf of the result tree.
+      result = result.apply(
+          [error](double leaf_value) { return leaf_value + error; });
+
+    } else if (auto dc = conditional->asDiscrete()) {
+      // If discrete, add the discrete error in the right branch
+      result = result.apply(
+          [dc](const Assignment<Key> &assignment, double leaf_value) {
+            return leaf_value + dc->error(DiscreteValues(assignment));
+          });
+    }
+  }
+
+  return result;
+}
+
 /* ************************************************************************* */
 AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
     const VectorValues &continuousValues) const {
diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h
index e71cfe9b4..2934ef176 100644
--- a/gtsam/hybrid/HybridBayesNet.h
+++ b/gtsam/hybrid/HybridBayesNet.h
@@ -187,6 +187,16 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    * @param continuousValues Continuous values at which to compute the error.
    * @return AlgebraicDecisionTree<Key>
    */
+  AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
+
+  /**
+   * @brief Compute log probability for each discrete assignment,
+   * and return as a tree.
+   *
+   * @param continuousValues Continuous values at which
+   * to compute the log probability.
+   * @return AlgebraicDecisionTree<Key>
+   */
   AlgebraicDecisionTree<Key> logProbability(
       const VectorValues &continuousValues) const;
 
diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp
index 5248fce01..66985cc78 100644
--- a/gtsam/hybrid/tests/testHybridBayesNet.cpp
+++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp
@@ -153,6 +153,45 @@ TEST(HybridBayesNet, Choose) {
                       *gbn.at(3)));
 }
 
+/* ****************************************************************************/
+// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
+TEST(HybridBayesNet, Error) {
+  const auto continuousConditional = GaussianConditional::sharedMeanAndStddev(
+      X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0);
+
+  const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)),
+                       model1 = noiseModel::Diagonal::Sigmas(Vector1(3.0));
+
+  const auto conditional0 = std::make_shared<GaussianConditional>(
+                 X(1), Vector1::Constant(5), I_1x1, model0),
+             conditional1 = std::make_shared<GaussianConditional>(
+                 X(1), Vector1::Constant(2), I_1x1, model1);
+
+  auto gm =
+      new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1});
+  // Create hybrid Bayes net.
+  HybridBayesNet bayesNet;
+  bayesNet.push_back(continuousConditional);
+  bayesNet.emplace_back(gm);
+  bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1"));
+
+  // Create values at which to evaluate.
+  HybridValues values;
+  values.insert(asiaKey, 0);
+  values.insert(X(0), Vector1(-6));
+  values.insert(X(1), Vector1(1));
+
+  AlgebraicDecisionTree<Key> actual_errors =
+      bayesNet.error(values.continuous());
+
+  // Regression.
+  // Manually added all the error values from the 3 conditional types.
+  AlgebraicDecisionTree<Key> expected_errors(
+      {Asia}, std::vector<double>{2.33005033585, 5.38619084965});
+
+  EXPECT(assert_equal(expected_errors, actual_errors));
+}
+
 /* ****************************************************************************/
 // Test Bayes net optimize
 TEST(HybridBayesNet, OptimizeAssignment) {