Much more comprehensive tests
parent
12349b9201
commit
44fb786b7a
|
@ -62,32 +62,117 @@ TEST(HybridBayesNet, Add) {
|
|||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test evaluate for a pure discrete Bayes net P(Asia).
|
||||
// Test API for a pure discrete Bayes net P(Asia).
|
||||
TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||
HybridBayesNet bayesNet;
|
||||
bayesNet.emplace_shared<DiscreteConditional>(Asia, "4/6");
|
||||
HybridValues values;
|
||||
values.insert(asiaKey, 0);
|
||||
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
|
||||
const auto pAsia = std::make_shared<DiscreteConditional>(Asia, "4/6");
|
||||
bayesNet.push_back(pAsia);
|
||||
HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}};
|
||||
|
||||
// choose
|
||||
GaussianBayesNet empty;
|
||||
EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9));
|
||||
|
||||
// evaluate
|
||||
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9);
|
||||
|
||||
// optimize
|
||||
EXPECT(assert_equal(one, bayesNet.optimize()));
|
||||
EXPECT(assert_equal(VectorValues{}, bayesNet.optimize(one.discrete())));
|
||||
|
||||
// sample
|
||||
std::mt19937_64 rng(42);
|
||||
EXPECT(assert_equal(zero, bayesNet.sample(&rng)));
|
||||
EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
|
||||
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
|
||||
|
||||
// error
|
||||
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
|
||||
|
||||
// logProbability
|
||||
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
|
||||
|
||||
// toFactorGraph
|
||||
HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({});
|
||||
EXPECT(assert_equal(expectedFG, fg));
|
||||
|
||||
// prune, imperative :-(
|
||||
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
|
||||
EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test creation of a tiny hybrid Bayes net.
|
||||
TEST(HybridBayesNet, Tiny) {
|
||||
auto bn = tiny::createHybridBayesNet();
|
||||
EXPECT_LONGS_EQUAL(3, bn.size());
|
||||
auto bayesNet = tiny::createHybridBayesNet(); // P(z|x,mode)P(x)P(mode)
|
||||
EXPECT_LONGS_EQUAL(3, bayesNet.size());
|
||||
|
||||
const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}};
|
||||
auto fg = bn.toFactorGraph(vv);
|
||||
HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}};
|
||||
|
||||
// Check Invariants for components
|
||||
HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid();
|
||||
GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()),
|
||||
gc1 = hgc->choose(one.discrete());
|
||||
GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian();
|
||||
GaussianConditional::CheckInvariants(*gc0, vv);
|
||||
GaussianConditional::CheckInvariants(*gc1, vv);
|
||||
GaussianConditional::CheckInvariants(*px, vv);
|
||||
HybridGaussianConditional::CheckInvariants(*hgc, zero);
|
||||
HybridGaussianConditional::CheckInvariants(*hgc, one);
|
||||
|
||||
// choose
|
||||
GaussianBayesNet expectedChosen;
|
||||
expectedChosen.push_back(gc0);
|
||||
expectedChosen.push_back(px);
|
||||
auto chosen0 = bayesNet.choose(zero.discrete());
|
||||
auto chosen1 = bayesNet.choose(one.discrete());
|
||||
EXPECT(assert_equal(expectedChosen, chosen0, 1e-9));
|
||||
|
||||
// logProbability
|
||||
const double logP0 = chosen0.logProbability(vv) + log(0.4); // 0.4 is prior
|
||||
const double logP1 = chosen1.logProbability(vv) + log(0.6); // 0.6 is prior
|
||||
EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9);
|
||||
|
||||
// evaluate
|
||||
EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9);
|
||||
|
||||
// optimize
|
||||
EXPECT(assert_equal(one, bayesNet.optimize()));
|
||||
EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete())));
|
||||
|
||||
// sample
|
||||
std::mt19937_64 rng(42);
|
||||
EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete()));
|
||||
|
||||
// error
|
||||
const double error0 = chosen0.error(vv) + gc0->negLogConstant() -
|
||||
px->negLogConstant() - log(0.4);
|
||||
const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
|
||||
px->negLogConstant() - log(0.6);
|
||||
EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
|
||||
|
||||
// toFactorGraph
|
||||
auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}});
|
||||
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||
|
||||
// Check that the ratio of probPrime to evaluate is the same for all modes.
|
||||
std::vector<double> ratio(2);
|
||||
for (size_t mode : {0, 1}) {
|
||||
const HybridValues hv{vv, {{M(0), mode}}};
|
||||
ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv);
|
||||
}
|
||||
ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
|
||||
ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
|
||||
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
||||
|
||||
// prune, imperative :-(
|
||||
auto pruned = bayesNet.prune(1);
|
||||
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
|
||||
EXPECT(!pruned.equals(bayesNet));
|
||||
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
@ -223,12 +308,15 @@ TEST(HybridBayesNet, Optimize) {
|
|||
/* ****************************************************************************/
|
||||
// Test Bayes net error
|
||||
TEST(HybridBayesNet, Pruning) {
|
||||
// Create switching network with three continuous variables and two discrete:
|
||||
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
|
||||
Switching s(3);
|
||||
|
||||
HybridBayesNet::shared_ptr posterior =
|
||||
s.linearizedFactorGraph.eliminateSequential();
|
||||
EXPECT_LONGS_EQUAL(5, posterior->size());
|
||||
|
||||
// Optimize
|
||||
HybridValues delta = posterior->optimize();
|
||||
auto actualTree = posterior->evaluate(delta.continuous());
|
||||
|
||||
|
@ -254,7 +342,6 @@ TEST(HybridBayesNet, Pruning) {
|
|||
logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues);
|
||||
logProbability += posterior->at(1)->asHybrid()->logProbability(hybridValues);
|
||||
logProbability += posterior->at(2)->asHybrid()->logProbability(hybridValues);
|
||||
// NOTE(dellaert): the discrete errors were not added in logProbability tree!
|
||||
logProbability +=
|
||||
posterior->at(3)->asDiscrete()->logProbability(hybridValues);
|
||||
logProbability +=
|
||||
|
@ -316,10 +403,9 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
#endif
|
||||
|
||||
// regression
|
||||
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||
DecisionTreeFactor::ADT potentials(
|
||||
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
|
||||
s.modes, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||
DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials);
|
||||
|
||||
// Prune!
|
||||
posterior->prune(maxNrLeaves);
|
||||
|
|
Loading…
Reference in New Issue