New DiscreteBayesTree tests
parent
8327812d29
commit
b2fbbbb26f
|
@ -369,6 +369,41 @@ TEST(DiscreteBayesTree, Lookup) {
|
||||||
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test creating a Bayes tree directly from cliques
|
||||||
|
TEST(DiscreteBayesTree, DirectFromCliques) {
|
||||||
|
// Create a BayesNet
|
||||||
|
DiscreteBayesNet bayesNet;
|
||||||
|
DiscreteKey key0(0, 2), key1(1, 2), key2(2, 2);
|
||||||
|
bayesNet.add(key0 % "1/3");
|
||||||
|
bayesNet.add(key1 | key0 = "1/3 3/1");
|
||||||
|
bayesNet.add(key2 | key1 = "3/1 3/1");
|
||||||
|
|
||||||
|
// Create cliques directly
|
||||||
|
auto clique2 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||||
|
std::make_shared<DiscreteConditional>(key2 | key1 = "3/1 3/1"));
|
||||||
|
auto clique1 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||||
|
std::make_shared<DiscreteConditional>(key1 | key0 = "1/3 3/1"));
|
||||||
|
auto clique0 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||||
|
std::make_shared<DiscreteConditional>(key0 % "1/3"));
|
||||||
|
|
||||||
|
// Create a BayesTree
|
||||||
|
DiscreteBayesTree bayesTree;
|
||||||
|
bayesTree.insertRoot(clique2);
|
||||||
|
bayesTree.addClique(clique1, clique2);
|
||||||
|
bayesTree.addClique(clique0, clique1);
|
||||||
|
|
||||||
|
// Check that the BayesTree is correct
|
||||||
|
DiscreteValues values;
|
||||||
|
values[0] = 1;
|
||||||
|
values[1] = 1;
|
||||||
|
values[2] = 1;
|
||||||
|
|
||||||
|
double expected = bayesNet.evaluate(values);
|
||||||
|
double actual = bayesTree.evaluate(values);
|
||||||
|
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -156,5 +156,36 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
values[X(3)] = 2
|
values[X(3)] = 2
|
||||||
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
|
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
|
||||||
|
|
||||||
|
def test_direct_from_cliques(self):
|
||||||
|
"""Test creating a Bayes tree directly from cliques."""
|
||||||
|
# Create a BayesNet
|
||||||
|
bayesNet = DiscreteBayesNet()
|
||||||
|
key0, key1, key2 = (0, 2), (1, 2), (2, 2)
|
||||||
|
bayesNet.add(key0, "1/3")
|
||||||
|
bayesNet.add(key1, [key0], "1/3 3/1")
|
||||||
|
bayesNet.add(key2, [key1], "3/1 3/1")
|
||||||
|
|
||||||
|
# Create cliques directly
|
||||||
|
clique2 = DiscreteBayesTreeClique(DiscreteConditional(key2, [key1], "3/1 3/1"))
|
||||||
|
clique1 = DiscreteBayesTreeClique(DiscreteConditional(key1, [key0], "1/3 3/1"))
|
||||||
|
clique0 = DiscreteBayesTreeClique(DiscreteConditional(key0, "1/3"))
|
||||||
|
|
||||||
|
# Create a BayesTree
|
||||||
|
bayesTree = gtsam.DiscreteBayesTree()
|
||||||
|
bayesTree.insertRoot(clique2)
|
||||||
|
bayesTree.addClique(clique1, clique2)
|
||||||
|
bayesTree.addClique(clique0, clique1)
|
||||||
|
|
||||||
|
# Check that the BayesTree is correct
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[0] = 1
|
||||||
|
values[1] = 1
|
||||||
|
values[2] = 1
|
||||||
|
|
||||||
|
expected = bayesNet.evaluate(values)
|
||||||
|
actual = bayesTree.evaluate(values)
|
||||||
|
self.assertAlmostEqual(expected, actual, places=9)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue