update testGaussianMixture
parent
462a5b8b3a
commit
5e1931eb98
|
|
@ -20,6 +20,7 @@
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/discrete/DiscreteTableConditional.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
|
|
@ -79,8 +80,9 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
||||||
double midway = mu1 - mu0;
|
double midway = mu1 - mu0;
|
||||||
auto eliminationResult =
|
auto eliminationResult =
|
||||||
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
|
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
|
||||||
auto pMid = *eliminationResult->at(0)->asDiscrete();
|
auto pMid = std::dynamic_pointer_cast<DiscreteTableConditional>(
|
||||||
EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid));
|
eliminationResult->at(0)->asDiscrete());
|
||||||
|
EXPECT(assert_equal(DiscreteTableConditional(m, "60/40"), *pMid));
|
||||||
|
|
||||||
// Everywhere else, the result should be a sigmoid.
|
// Everywhere else, the result should be a sigmoid.
|
||||||
for (const double shift : {-4, -2, 0, 2, 4}) {
|
for (const double shift : {-4, -2, 0, 2, 4}) {
|
||||||
|
|
@ -90,7 +92,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
||||||
// Workflow 1: convert HBN to HFG and solve
|
// Workflow 1: convert HBN to HFG and solve
|
||||||
auto eliminationResult1 =
|
auto eliminationResult1 =
|
||||||
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
||||||
auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
|
auto posterior1 = *std::dynamic_pointer_cast<DiscreteTableConditional>(
|
||||||
|
eliminationResult1->at(0)->asDiscrete());
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||||
|
|
||||||
// Workflow 2: directly specify HFG and solve
|
// Workflow 2: directly specify HFG and solve
|
||||||
|
|
@ -99,7 +102,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
||||||
m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)});
|
m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)});
|
||||||
hfg1.push_back(mixing);
|
hfg1.push_back(mixing);
|
||||||
auto eliminationResult2 = hfg1.eliminateSequential();
|
auto eliminationResult2 = hfg1.eliminateSequential();
|
||||||
auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
|
auto posterior2 = *std::dynamic_pointer_cast<DiscreteTableConditional>(
|
||||||
|
eliminationResult2->at(0)->asDiscrete());
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -133,13 +137,14 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
||||||
// Eliminate the graph!
|
// Eliminate the graph!
|
||||||
auto eliminationResultMax = gfg.eliminateSequential();
|
auto eliminationResultMax = gfg.eliminateSequential();
|
||||||
|
|
||||||
// Equality of posteriors asserts that the elimination is correct (same ratios
|
// Equality of posteriors asserts that the elimination is correct
|
||||||
// for all modes)
|
// (same ratios for all modes)
|
||||||
EXPECT(assert_equal(expectedDiscretePosterior,
|
EXPECT(assert_equal(expectedDiscretePosterior,
|
||||||
eliminationResultMax->discretePosterior(vv)));
|
eliminationResultMax->discretePosterior(vv)));
|
||||||
|
|
||||||
auto pMax = *eliminationResultMax->at(0)->asDiscrete();
|
auto pMax = *std::dynamic_pointer_cast<DiscreteTableConditional>(
|
||||||
EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));
|
eliminationResultMax->at(0)->asDiscrete());
|
||||||
|
EXPECT(assert_equal(DiscreteTableConditional(m, "42/58"), pMax, 1e-4));
|
||||||
|
|
||||||
// Everywhere else, the result should be a bell curve like function.
|
// Everywhere else, the result should be a bell curve like function.
|
||||||
for (const double shift : {-4, -2, 0, 2, 4}) {
|
for (const double shift : {-4, -2, 0, 2, 4}) {
|
||||||
|
|
@ -149,7 +154,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
||||||
// Workflow 1: convert HBN to HFG and solve
|
// Workflow 1: convert HBN to HFG and solve
|
||||||
auto eliminationResult1 =
|
auto eliminationResult1 =
|
||||||
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
|
||||||
auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
|
auto posterior1 = *std::dynamic_pointer_cast<DiscreteTableConditional>(
|
||||||
|
eliminationResult1->at(0)->asDiscrete());
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||||
|
|
||||||
// Workflow 2: directly specify HFG and solve
|
// Workflow 2: directly specify HFG and solve
|
||||||
|
|
@ -158,7 +164,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
||||||
m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)});
|
m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)});
|
||||||
hfg.push_back(mixing);
|
hfg.push_back(mixing);
|
||||||
auto eliminationResult2 = hfg.eliminateSequential();
|
auto eliminationResult2 = hfg.eliminateSequential();
|
||||||
auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
|
auto posterior2 = *std::dynamic_pointer_cast<DiscreteTableConditional>(
|
||||||
|
eliminationResult2->at(0)->asDiscrete());
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue