more discrete BT wrapper + test
parent
10f30e1ca9
commit
3d979763f0
|
@ -211,6 +211,11 @@ class DiscreteBayesTreeClique {
|
|||
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
||||
const gtsam::DiscreteConditional* conditional() const;
|
||||
bool isRoot() const;
|
||||
size_t nrChildren() const;
|
||||
const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const;
|
||||
void print(string s = "DiscreteBayesTreeClique",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
void printSignature(
|
||||
const string& s = "Clique: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
|
|
|
@ -140,9 +140,15 @@ namespace gtsam {
|
|||
/** Access the conditional */
|
||||
const sharedConditional& conditional() const { return conditional_; }
|
||||
|
||||
/** is this the root of a Bayes tree ? */
|
||||
/// Is this the root of a Bayes tree?
|
||||
inline bool isRoot() const { return parent_.expired(); }
|
||||
|
||||
/// Return the number of children.
|
||||
size_t nrChildren() const { return children.size(); }
|
||||
|
||||
/// Return the child at index i.
|
||||
const derived_ptr operator[](size_t i) const { return children[i]; }
|
||||
|
||||
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
|
||||
size_t treeSize() const;
|
||||
|
||||
|
|
|
@ -125,29 +125,36 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
|
||||
|
||||
# Check that the lookup table is correct
|
||||
assert len(lookup) == 2
|
||||
assert lookup.size() == 2
|
||||
lookup_x1_a1_x2 = lookup[X(1)].conditional()
|
||||
assert len(lookup_x1_a1_x2.frontals()) == 3
|
||||
assert lookup_x1_a1_x2.nrFrontals() == 3
|
||||
# Check that sum is 100
|
||||
empty = gtsam.DiscreteValues()
|
||||
assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9)
|
||||
self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 100)
|
||||
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
||||
assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9)
|
||||
values = DiscreteValues()
|
||||
values[X(1)] = 0
|
||||
values[A(1)] = 1
|
||||
values[X(2)] = 1
|
||||
self.assertAlmostEqual(lookup_x1_a1_x2(values), 100)
|
||||
|
||||
lookup_a2_x3 = lookup[X(3)].conditional()
|
||||
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
|
||||
sum_x2 = lookup_a2_x3.sum(2)
|
||||
assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9)
|
||||
assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9)
|
||||
assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9)
|
||||
assert len(lookup_a2_x3.frontals()) == 2
|
||||
# And that the non-zero rewards are for
|
||||
# x2 a2 x3 == 1 1 2
|
||||
assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9)
|
||||
# x2 a2 x3 == 2 0 2
|
||||
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9)
|
||||
# x2 a2 x3 == 2 1 2
|
||||
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9)
|
||||
values = DiscreteValues()
|
||||
values[X(2)] = 0
|
||||
self.assertAlmostEqual(sum_x2(values), 0)
|
||||
values[X(2)] = 1
|
||||
self.assertAlmostEqual(sum_x2(values), 10)
|
||||
values[X(2)] = 2
|
||||
self.assertAlmostEqual(sum_x2(values), 20)
|
||||
assert lookup_a2_x3.nrFrontals() == 2
|
||||
# And that the non-zero rewards are for x2 a2 x3 == 1 1 2
|
||||
values = DiscreteValues()
|
||||
values[X(2)] = 1
|
||||
values[A(2)] = 1
|
||||
values[X(3)] = 2
|
||||
self.assertAlmostEqual(lookup_a2_x3(values), 10)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue