From 3d979763f0797af41f250622a99db630af14139d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 19 Jun 2023 07:38:19 -0700 Subject: [PATCH] more discrete BT wrapper + test --- gtsam/discrete/discrete.i | 5 +++ gtsam/inference/BayesTreeCliqueBase.h | 8 ++++- python/gtsam/tests/test_DiscreteBayesTree.py | 37 ++++++++++++-------- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index b66069340..fe8cbc7f3 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -211,6 +211,11 @@ class DiscreteBayesTreeClique { DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional); const gtsam::DiscreteConditional* conditional() const; bool isRoot() const; + size_t nrChildren() const; + const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const; + void print(string s = "DiscreteBayesTreeClique", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; void printSignature( const string& s = "Clique: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index c9bb6db94..c60d6a8c8 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -140,9 +140,15 @@ namespace gtsam { /** Access the conditional */ const sharedConditional& conditional() const { return conditional_; } - /** is this the root of a Bayes tree ? */ + /// Is this the root of a Bayes tree? inline bool isRoot() const { return parent_.expired(); } + /// Return the number of children. + size_t nrChildren() const { return children.size(); } + + /// Return the child at index i. + const derived_ptr operator[](size_t i) const { return children[i]; } + /** The size of subtree rooted at this clique, i.e., nr of Cliques */ size_t treeSize() const; diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index 720f884e0..09fb5a865 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -125,29 +125,36 @@ class TestDiscreteBayesNet(GtsamTestCase): lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE) # Check that the lookup table is correct - assert len(lookup) == 2 + assert lookup.size() == 2 lookup_x1_a1_x2 = lookup[X(1)].conditional() - assert len(lookup_x1_a1_x2.frontals()) == 3 + assert lookup_x1_a1_x2.nrFrontals() == 3 # Check that sum is 100 empty = gtsam.DiscreteValues() - assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9) + self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 100) # And that only non-zero reward is for x1 a1 x2 == 0 1 1 - assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9) + values = DiscreteValues() + values[X(1)] = 0 + values[A(1)] = 1 + values[X(2)] = 1 + self.assertAlmostEqual(lookup_x1_a1_x2(values), 100) lookup_a2_x3 = lookup[X(3)].conditional() # Check that the sum depends on x2 and is non-zero only for x2 in {1, 2} sum_x2 = lookup_a2_x3.sum(2) - assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9) - assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9) - assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9) - assert len(lookup_a2_x3.frontals()) == 2 - # And that the non-zero rewards are for - # x2 a2 x3 == 1 1 2 - assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9) - # x2 a2 x3 == 2 0 2 - assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9) - # x2 a2 x3 == 2 1 2 - assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9) + values = DiscreteValues() + values[X(2)] = 0 + self.assertAlmostEqual(sum_x2(values), 0) + values[X(2)] = 1 + self.assertAlmostEqual(sum_x2(values), 10) + values[X(2)] = 2 + self.assertAlmostEqual(sum_x2(values), 20) + assert lookup_a2_x3.nrFrontals() == 2 + # And that the non-zero rewards are for x2 a2 x3 == 1 1 2 + values = DiscreteValues() + values[X(2)] = 1 + values[A(2)] = 1 + values[X(3)] = 2 + self.assertAlmostEqual(lookup_a2_x3(values), 10) if __name__ == "__main__": unittest.main()