diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp
index 5f655c990..9df0012c7 100644
--- a/gtsam/hybrid/HybridBayesNet.cpp
+++ b/gtsam/hybrid/HybridBayesNet.cpp
@@ -194,40 +194,6 @@ HybridValues HybridBayesNet::sample() const {
   return sample(&kRandomNumberGenerator);
 }
 
-/* ************************************************************************* */
-AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
-    const VectorValues &continuousValues) const {
-  AlgebraicDecisionTree<Key> result(0.0);
-
-  // Iterate over each conditional.
-  for (auto &&conditional : *this) {
-    if (auto gm = conditional->asHybrid()) {
-      // If conditional is hybrid, compute error for all assignments.
-      result = result + gm->errorTree(continuousValues);
-
-    } else if (auto gc = conditional->asGaussian()) {
-      // If continuous, get the error and add it to the result
-      double error = gc->error(continuousValues);
-      // Add the computed error to every leaf of the result tree.
-      result = result.apply(
-          [error](double leaf_value) { return leaf_value + error; });
-
-    } else if (auto dc = conditional->asDiscrete()) {
-      // If discrete, add the discrete error in the right branch
-      if (result.nrLeaves() == 1) {
-        result = dc->errorTree();
-      } else {
-        result = result.apply(
-            [dc](const Assignment<Key> &assignment, double leaf_value) {
-              return leaf_value + dc->error(DiscreteValues(assignment));
-            });
-      }
-    }
-  }
-
-  return result;
-}
-
 /* ************************************************************************* */
 AlgebraicDecisionTree<Key> HybridBayesNet::logDiscretePosteriorPrime(
     const VectorValues &continuousValues) const {
diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h
index 9e621ea20..fba6bb6aa 100644
--- a/gtsam/hybrid/HybridBayesNet.h
+++ b/gtsam/hybrid/HybridBayesNet.h
@@ -210,16 +210,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   HybridBayesNet prune(size_t maxNrLeaves) const;
 
-  /**
-   * @brief Compute conditional error for each discrete assignment,
-   * and return as a tree.
-   *
-   * @param continuousValues Continuous values at which to compute the error.
-   * @return AlgebraicDecisionTree<Key>
-   */
-  AlgebraicDecisionTree<Key> errorTree(
-      const VectorValues &continuousValues) const;
-
   /**
    * @brief Error method using HybridValues which returns specific error for
    * assignment.
diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp
index ee47a698a..9974827e8 100644
--- a/gtsam/hybrid/tests/testHybridBayesNet.cpp
+++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp
@@ -95,12 +95,6 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
   EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
   EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
 
-  // errorTree
-  AlgebraicDecisionTree<Key> actual = bayesNet.errorTree({});
-  AlgebraicDecisionTree<Key> expectedErrorTree(
-      {Asia}, std::vector<double>{-log(0.4), -log(0.6)});
-  EXPECT(assert_equal(expectedErrorTree, actual));
-
   // error
   EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
   EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
@@ -127,20 +121,73 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
 /* ****************************************************************************/
 // Test creation of a tiny hybrid Bayes net.
 TEST(HybridBayesNet, Tiny) {
-  auto bn = tiny::createHybridBayesNet();
-  EXPECT_LONGS_EQUAL(3, bn.size());
+  auto bayesNet = tiny::createHybridBayesNet();  // P(z|x,mode)P(x)P(mode)
+  EXPECT_LONGS_EQUAL(3, bayesNet.size());
 
   const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}};
-  auto fg = bn.toFactorGraph(vv);
-  EXPECT_LONGS_EQUAL(3, fg.size());
+  HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}};
 
-  // Check that the ratio of probPrime to evaluate is the same for all modes.
-  std::vector<double> ratio(2);
-  for (size_t mode : {0, 1}) {
-    const HybridValues hv{vv, {{M(0), mode}}};
-    ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv);
-  }
-  EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
+  // choose
+  HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid();
+  GaussianBayesNet chosen;
+  chosen.push_back(hgc->choose(zero.discrete()));
+  chosen.push_back(bayesNet.at(1)->asGaussian());
+  EXPECT(assert_equal(chosen, bayesNet.choose(zero.discrete()), 1e-9));
+
+  // logProbability
+  const double logP0 = chosen.logProbability(vv) + log(0.4);  // 0.4 is prior
+  const double logP1 = bayesNet.choose(one.discrete()).logProbability(vv) + log(0.6);  // 0.6 is prior
+  EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9);
+  EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9);
+
+  // evaluate
+  EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9);
+
+  // optimize
+  EXPECT(assert_equal(one, bayesNet.optimize()));
+  EXPECT(assert_equal(chosen.optimize(), bayesNet.optimize(zero.discrete())));
+
+  // sample
+  std::mt19937_64 rng(42);
+  EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete()));
+
+  // prune
+  auto pruned = bayesNet.prune(1);
+  EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
+  EXPECT(!pruned.equals(bayesNet));
+
+  // // error
+  // EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
+  // EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
+
+  // logDiscretePosteriorPrime, TODO: useless as -errorTree?
+  AlgebraicDecisionTree<Key> expected(M(0), logP0, logP1);
+  EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime(vv)));
+
+  // // logProbability
+  // EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
+  // EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
+
+  // // discretePosterior
+  // AlgebraicDecisionTree<Key> expectedPosterior({Asia},
+  //                                              std::vector<double>{0.4,
+  //                                                                  0.6});
+  // EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
+
+  // // toFactorGraph
+  // HybridGaussianFactorGraph expectedFG{};
+
+  // auto fg = bayesNet.toFactorGraph(vv);
+  // EXPECT_LONGS_EQUAL(3, fg.size());
+  // EXPECT(assert_equal(expectedFG, fg));
+
+  // // Check that the ratio of probPrime to evaluate is the same for all modes.
+  // std::vector<double> ratio(2);
+  // ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
+  // ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
+  // EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
+
+  // TODO: better test: check if discretePosteriors agree!
 }
 
 /* ****************************************************************************/
@@ -174,21 +221,6 @@ TEST(HybridBayesNet, evaluateHybrid) {
                        bayesNet.evaluate(values), 1e-9);
 }
 
-/* ****************************************************************************/
-// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
-TEST(HybridBayesNet, Error) {
-  using namespace different_sigmas;
-
-  AlgebraicDecisionTree<Key> actual = bayesNet.errorTree(values.continuous());
-
-  // Regression.
-  // Manually added all the error values from the 3 conditional types.
-  AlgebraicDecisionTree<Key> expected(
-      {Asia}, std::vector<double>{2.33005033585, 5.38619084965});
-
-  EXPECT(assert_equal(expected, actual));
-}
-
 /* ****************************************************************************/
 // Test choosing an assignment of conditionals
 TEST(HybridBayesNet, Choose) {
diff --git a/gtsam/hybrid/tests/testHybridGaussianFactor.cpp b/gtsam/hybrid/tests/testHybridGaussianFactor.cpp
index c2ffe24c8..5ff8c1478 100644
--- a/gtsam/hybrid/tests/testHybridGaussianFactor.cpp
+++ b/gtsam/hybrid/tests/testHybridGaussianFactor.cpp
@@ -357,16 +357,9 @@ TEST(HybridGaussianFactor, DifferentCovariancesFG) {
   cv.insert(X(0), Vector1(0.0));
   cv.insert(X(1), Vector1(0.0));
 
-  // Check that the error values at the MLE point μ.
-  AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
-
   DiscreteValues dv0{{M(1), 0}};
   DiscreteValues dv1{{M(1), 1}};
 
-  // regression
-  EXPECT_DOUBLES_EQUAL(9.90348755254, errorTree(dv0), 1e-9);
-  EXPECT_DOUBLES_EQUAL(0.69314718056, errorTree(dv1), 1e-9);
-
   DiscreteConditional expected_m1(m1, "0.5/0.5");
   DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
index 4735c1657..3a26f4486 100644
--- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
+++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
@@ -994,16 +994,9 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) {
   cv.insert(X(0), Vector1(0.0));
   cv.insert(X(1), Vector1(0.0));
 
-  // Check that the error values at the MLE point μ.
-  AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
-
   DiscreteValues dv0{{M(1), 0}};
   DiscreteValues dv1{{M(1), 1}};
 
-  // regression
-  EXPECT_DOUBLES_EQUAL(9.90348755254, errorTree(dv0), 1e-9);
-  EXPECT_DOUBLES_EQUAL(0.69314718056, errorTree(dv1), 1e-9);
-
   DiscreteConditional expected_m1(m1, "0.5/0.5");
   DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());