more discrete BT wrapper + test

release/4.3a0
Frank Dellaert 2023-06-19 07:38:19 -07:00
parent 10f30e1ca9
commit 3d979763f0
3 changed files with 34 additions and 16 deletions

View File

@ -211,6 +211,11 @@ class DiscreteBayesTreeClique {
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
const gtsam::DiscreteConditional* conditional() const;
bool isRoot() const;
size_t nrChildren() const;
const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const;
void print(string s = "DiscreteBayesTreeClique",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printSignature(
const string& s = "Clique: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;

View File

@ -140,9 +140,15 @@ namespace gtsam {
/** Access the conditional */
const sharedConditional& conditional() const { return conditional_; }
/** is this the root of a Bayes tree ? */
/// Is this the root of a Bayes tree?
inline bool isRoot() const { return parent_.expired(); }
/// Return the number of children.
size_t nrChildren() const { return children.size(); }
/// Return the child at index i.
const derived_ptr operator[](size_t i) const { return children[i]; }
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
size_t treeSize() const;

View File

@ -125,29 +125,36 @@ class TestDiscreteBayesNet(GtsamTestCase):
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
# Check that the lookup table is correct
assert len(lookup) == 2
assert lookup.size() == 2
lookup_x1_a1_x2 = lookup[X(1)].conditional()
assert len(lookup_x1_a1_x2.frontals()) == 3
assert lookup_x1_a1_x2.nrFrontals() == 3
# Check that sum is 100
empty = gtsam.DiscreteValues()
assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9)
self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 100)
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9)
values = DiscreteValues()
values[X(1)] = 0
values[A(1)] = 1
values[X(2)] = 1
self.assertAlmostEqual(lookup_x1_a1_x2(values), 100)
lookup_a2_x3 = lookup[X(3)].conditional()
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
sum_x2 = lookup_a2_x3.sum(2)
assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9)
assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9)
assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9)
assert len(lookup_a2_x3.frontals()) == 2
# And that the non-zero rewards are for
# x2 a2 x3 == 1 1 2
assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9)
# x2 a2 x3 == 2 0 2
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9)
# x2 a2 x3 == 2 1 2
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9)
values = DiscreteValues()
values[X(2)] = 0
self.assertAlmostEqual(sum_x2(values), 0)
values[X(2)] = 1
self.assertAlmostEqual(sum_x2(values), 10)
values[X(2)] = 2
self.assertAlmostEqual(sum_x2(values), 20)
assert lookup_a2_x3.nrFrontals() == 2
# And that the non-zero rewards are for x2 a2 x3 == 1 1 2
values = DiscreteValues()
values[X(2)] = 1
values[A(2)] = 1
values[X(3)] = 2
self.assertAlmostEqual(lookup_a2_x3(values), 10)
if __name__ == "__main__":
unittest.main()