New DiscreteBayesTree tests
parent
8327812d29
commit
b2fbbbb26f
|
@ -369,6 +369,41 @@ TEST(DiscreteBayesTree, Lookup) {
|
|||
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test creating a Bayes tree directly from cliques
|
||||
TEST(DiscreteBayesTree, DirectFromCliques) {
|
||||
// Create a BayesNet
|
||||
DiscreteBayesNet bayesNet;
|
||||
DiscreteKey key0(0, 2), key1(1, 2), key2(2, 2);
|
||||
bayesNet.add(key0 % "1/3");
|
||||
bayesNet.add(key1 | key0 = "1/3 3/1");
|
||||
bayesNet.add(key2 | key1 = "3/1 3/1");
|
||||
|
||||
// Create cliques directly
|
||||
auto clique2 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||
std::make_shared<DiscreteConditional>(key2 | key1 = "3/1 3/1"));
|
||||
auto clique1 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||
std::make_shared<DiscreteConditional>(key1 | key0 = "1/3 3/1"));
|
||||
auto clique0 = std::make_shared<DiscreteBayesTree::Clique>(
|
||||
std::make_shared<DiscreteConditional>(key0 % "1/3"));
|
||||
|
||||
// Create a BayesTree
|
||||
DiscreteBayesTree bayesTree;
|
||||
bayesTree.insertRoot(clique2);
|
||||
bayesTree.addClique(clique1, clique2);
|
||||
bayesTree.addClique(clique0, clique1);
|
||||
|
||||
// Check that the BayesTree is correct
|
||||
DiscreteValues values;
|
||||
values[0] = 1;
|
||||
values[1] = 1;
|
||||
values[2] = 1;
|
||||
|
||||
double expected = bayesNet.evaluate(values);
|
||||
double actual = bayesTree.evaluate(values);
|
||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -156,5 +156,36 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
values[X(3)] = 2
|
||||
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
|
||||
|
||||
def test_direct_from_cliques(self):
|
||||
"""Test creating a Bayes tree directly from cliques."""
|
||||
# Create a BayesNet
|
||||
bayesNet = DiscreteBayesNet()
|
||||
key0, key1, key2 = (0, 2), (1, 2), (2, 2)
|
||||
bayesNet.add(key0, "1/3")
|
||||
bayesNet.add(key1, [key0], "1/3 3/1")
|
||||
bayesNet.add(key2, [key1], "3/1 3/1")
|
||||
|
||||
# Create cliques directly
|
||||
clique2 = DiscreteBayesTreeClique(DiscreteConditional(key2, [key1], "3/1 3/1"))
|
||||
clique1 = DiscreteBayesTreeClique(DiscreteConditional(key1, [key0], "1/3 3/1"))
|
||||
clique0 = DiscreteBayesTreeClique(DiscreteConditional(key0, "1/3"))
|
||||
|
||||
# Create a BayesTree
|
||||
bayesTree = gtsam.DiscreteBayesTree()
|
||||
bayesTree.insertRoot(clique2)
|
||||
bayesTree.addClique(clique1, clique2)
|
||||
bayesTree.addClique(clique0, clique1)
|
||||
|
||||
# Check that the BayesTree is correct
|
||||
values = DiscreteValues()
|
||||
values[0] = 1
|
||||
values[1] = 1
|
||||
values[2] = 1
|
||||
|
||||
expected = bayesNet.evaluate(values)
|
||||
actual = bayesTree.evaluate(values)
|
||||
self.assertAlmostEqual(expected, actual, places=9)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue