All tests for tiny work

release/4.3a0
Frank Dellaert 2024-09-30 01:11:18 -07:00
parent 50809001e1
commit 78b47770c0
2 changed files with 41 additions and 37 deletions

View File

@ -14,6 +14,7 @@
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme * @brief A hybrid conditional in the Conditional Linear Gaussian scheme
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
* @author Frank Dellaert
* @date Mar 12, 2022 * @date Mar 12, 2022
*/ */

View File

@ -73,10 +73,6 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
GaussianBayesNet empty; GaussianBayesNet empty;
EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9)); EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9));
// logProbability
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
// evaluate // evaluate
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9); EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9); EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9);
@ -127,16 +123,28 @@ TEST(HybridBayesNet, Tiny) {
const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}}; const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}};
HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}}; HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}};
// choose // Check Invariants for components
HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid(); HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid();
GaussianBayesNet chosen; GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()),
chosen.push_back(hgc->choose(zero.discrete())); gc1 = hgc->choose(one.discrete());
chosen.push_back(bayesNet.at(1)->asGaussian()); GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian();
EXPECT(assert_equal(chosen, bayesNet.choose(zero.discrete()), 1e-9)); GaussianConditional::CheckInvariants(*gc0, vv);
GaussianConditional::CheckInvariants(*gc1, vv);
GaussianConditional::CheckInvariants(*px, vv);
HybridGaussianConditional::CheckInvariants(*hgc, zero);
HybridGaussianConditional::CheckInvariants(*hgc, one);
// choose
GaussianBayesNet expectedChosen;
expectedChosen.push_back(gc0);
expectedChosen.push_back(px);
auto chosen0 = bayesNet.choose(zero.discrete());
auto chosen1 = bayesNet.choose(one.discrete());
EXPECT(assert_equal(expectedChosen, chosen0, 1e-9));
// logProbability // logProbability
const double logP0 = chosen.logProbability(vv) + log(0.4); // 0.4 is prior const double logP0 = chosen0.logProbability(vv) + log(0.4); // 0.4 is prior
const double logP1 = bayesNet.choose(one.discrete()).logProbability(vv) + log(0.6); // 0.6 is prior const double logP1 = chosen1.logProbability(vv) + log(0.6); // 0.6 is prior
EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9); EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9); EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9);
@ -145,7 +153,7 @@ TEST(HybridBayesNet, Tiny) {
// optimize // optimize
EXPECT(assert_equal(one, bayesNet.optimize())); EXPECT(assert_equal(one, bayesNet.optimize()));
EXPECT(assert_equal(chosen.optimize(), bayesNet.optimize(zero.discrete()))); EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete())));
// sample // sample
std::mt19937_64 rng(42); std::mt19937_64 rng(42);
@ -156,38 +164,33 @@ TEST(HybridBayesNet, Tiny) {
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents()); EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
EXPECT(!pruned.equals(bayesNet)); EXPECT(!pruned.equals(bayesNet));
// // error // error
// EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9); const double error0 = chosen0.error(vv) + gc0->negLogConstant() -
// EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9); px->negLogConstant() - log(0.4);
const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
px->negLogConstant() - log(0.6);
EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
// logDiscretePosteriorPrime, TODO: useless as -errorTree? // logDiscretePosteriorPrime, TODO: useless as -errorTree?
AlgebraicDecisionTree<Key> expected(M(0), logP0, logP1); AlgebraicDecisionTree<Key> expected(M(0), logP0, logP1);
EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime(vv))); EXPECT(assert_equal(expected, bayesNet.logDiscretePosteriorPrime(vv)));
// // logProbability // discretePosterior
// EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9); double q0 = std::exp(logP0), q1 = std::exp(logP1), sum = q0 + q1;
// EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9); AlgebraicDecisionTree<Key> expectedPosterior(M(0), q0 / sum, q1 / sum);
EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior(vv)));
// // discretePosterior // toFactorGraph
// AlgebraicDecisionTree<Key> expectedPosterior({Asia}, auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}});
// std::vector<double>{0.4, EXPECT_LONGS_EQUAL(3, fg.size());
// 0.6});
// EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
// // toFactorGraph // Check that the ratio of probPrime to evaluate is the same for all modes.
// HybridGaussianFactorGraph expectedFG{}; std::vector<double> ratio(2);
ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
// auto fg = bayesNet.toFactorGraph(vv); ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
// EXPECT_LONGS_EQUAL(3, fg.size()); EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
// EXPECT(assert_equal(expectedFG, fg));
// // Check that the ratio of probPrime to evaluate is the same for all modes.
// std::vector<double> ratio(2);
// ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
// ratio[0] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
// EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
// TODO: better test: check if discretePosteriors agree !
} }
/* ****************************************************************************/ /* ****************************************************************************/