New DiscreteBayesTree tests

release/4.3a0
Frank Dellaert 2025-01-22 12:05:27 -05:00
parent 8327812d29
commit b2fbbbb26f
2 changed files with 66 additions and 0 deletions

View File

@ -369,6 +369,41 @@ TEST(DiscreteBayesTree, Lookup) {
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
}
/* ************************************************************************* */
// Test creating a Bayes tree directly from cliques
TEST(DiscreteBayesTree, DirectFromCliques) {
// Create a BayesNet
DiscreteBayesNet bayesNet;
DiscreteKey key0(0, 2), key1(1, 2), key2(2, 2);
bayesNet.add(key0 % "1/3");
bayesNet.add(key1 | key0 = "1/3 3/1");
bayesNet.add(key2 | key1 = "3/1 3/1");
// Create cliques directly
auto clique2 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(key2 | key1 = "3/1 3/1"));
auto clique1 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(key1 | key0 = "1/3 3/1"));
auto clique0 = std::make_shared<DiscreteBayesTree::Clique>(
std::make_shared<DiscreteConditional>(key0 % "1/3"));
// Create a BayesTree
DiscreteBayesTree bayesTree;
bayesTree.insertRoot(clique2);
bayesTree.addClique(clique1, clique2);
bayesTree.addClique(clique0, clique1);
// Check that the BayesTree is correct
DiscreteValues values;
values[0] = 1;
values[1] = 1;
values[2] = 1;
double expected = bayesNet.evaluate(values);
double actual = bayesTree.evaluate(values);
DOUBLES_EQUAL(expected, actual, 1e-9);
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -156,5 +156,36 @@ class TestDiscreteBayesNet(GtsamTestCase):
values[X(3)] = 2
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
def test_direct_from_cliques(self):
"""Test creating a Bayes tree directly from cliques."""
# Create a BayesNet
bayesNet = DiscreteBayesNet()
key0, key1, key2 = (0, 2), (1, 2), (2, 2)
bayesNet.add(key0, "1/3")
bayesNet.add(key1, [key0], "1/3 3/1")
bayesNet.add(key2, [key1], "3/1 3/1")
# Create cliques directly
clique2 = DiscreteBayesTreeClique(DiscreteConditional(key2, [key1], "3/1 3/1"))
clique1 = DiscreteBayesTreeClique(DiscreteConditional(key1, [key0], "1/3 3/1"))
clique0 = DiscreteBayesTreeClique(DiscreteConditional(key0, "1/3"))
# Create a BayesTree
bayesTree = gtsam.DiscreteBayesTree()
bayesTree.insertRoot(clique2)
bayesTree.addClique(clique1, clique2)
bayesTree.addClique(clique0, clique1)
# Check that the BayesTree is correct
values = DiscreteValues()
values[0] = 1
values[1] = 1
values[2] = 1
expected = bayesNet.evaluate(values)
actual = bayesTree.evaluate(values)
self.assertAlmostEqual(expected, actual, places=9)
if __name__ == "__main__":
unittest.main()