more discrete BT wrapper + test
parent
10f30e1ca9
commit
3d979763f0
|
@ -211,6 +211,11 @@ class DiscreteBayesTreeClique {
|
||||||
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
||||||
const gtsam::DiscreteConditional* conditional() const;
|
const gtsam::DiscreteConditional* conditional() const;
|
||||||
bool isRoot() const;
|
bool isRoot() const;
|
||||||
|
size_t nrChildren() const;
|
||||||
|
const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const;
|
||||||
|
void print(string s = "DiscreteBayesTreeClique",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void printSignature(
|
void printSignature(
|
||||||
const string& s = "Clique: ",
|
const string& s = "Clique: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
@ -140,9 +140,15 @@ namespace gtsam {
|
||||||
/** Access the conditional */
|
/** Access the conditional */
|
||||||
const sharedConditional& conditional() const { return conditional_; }
|
const sharedConditional& conditional() const { return conditional_; }
|
||||||
|
|
||||||
/** is this the root of a Bayes tree ? */
|
/// Is this the root of a Bayes tree?
|
||||||
inline bool isRoot() const { return parent_.expired(); }
|
inline bool isRoot() const { return parent_.expired(); }
|
||||||
|
|
||||||
|
/// Return the number of children.
|
||||||
|
size_t nrChildren() const { return children.size(); }
|
||||||
|
|
||||||
|
/// Return the child at index i.
|
||||||
|
const derived_ptr operator[](size_t i) const { return children[i]; }
|
||||||
|
|
||||||
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
|
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
|
||||||
size_t treeSize() const;
|
size_t treeSize() const;
|
||||||
|
|
||||||
|
|
|
@ -125,29 +125,36 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
|
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
|
||||||
|
|
||||||
# Check that the lookup table is correct
|
# Check that the lookup table is correct
|
||||||
assert len(lookup) == 2
|
assert lookup.size() == 2
|
||||||
lookup_x1_a1_x2 = lookup[X(1)].conditional()
|
lookup_x1_a1_x2 = lookup[X(1)].conditional()
|
||||||
assert len(lookup_x1_a1_x2.frontals()) == 3
|
assert lookup_x1_a1_x2.nrFrontals() == 3
|
||||||
# Check that sum is 100
|
# Check that sum is 100
|
||||||
empty = gtsam.DiscreteValues()
|
empty = gtsam.DiscreteValues()
|
||||||
assert np.isclose(lookup_x1_a1_x2.sum(3)(empty), 100, atol=1e-9)
|
self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 100)
|
||||||
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
||||||
assert np.isclose(lookup_x1_a1_x2({X(1): 0, A(1): 1, X(2): 1}), 100, atol=1e-9)
|
values = DiscreteValues()
|
||||||
|
values[X(1)] = 0
|
||||||
|
values[A(1)] = 1
|
||||||
|
values[X(2)] = 1
|
||||||
|
self.assertAlmostEqual(lookup_x1_a1_x2(values), 100)
|
||||||
|
|
||||||
lookup_a2_x3 = lookup[X(3)].conditional()
|
lookup_a2_x3 = lookup[X(3)].conditional()
|
||||||
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
|
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
|
||||||
sum_x2 = lookup_a2_x3.sum(2)
|
sum_x2 = lookup_a2_x3.sum(2)
|
||||||
assert np.isclose(sum_x2({X(2): 0}), 0, atol=1e-9)
|
values = DiscreteValues()
|
||||||
assert np.isclose(sum_x2({X(2): 1}), 10, atol=1e-9)
|
values[X(2)] = 0
|
||||||
assert np.isclose(sum_x2({X(2): 2}), 20, atol=1e-9)
|
self.assertAlmostEqual(sum_x2(values), 0)
|
||||||
assert len(lookup_a2_x3.frontals()) == 2
|
values[X(2)] = 1
|
||||||
# And that the non-zero rewards are for
|
self.assertAlmostEqual(sum_x2(values), 10)
|
||||||
# x2 a2 x3 == 1 1 2
|
values[X(2)] = 2
|
||||||
assert np.isclose(lookup_a2_x3({X(2): 1, A(2): 1, X(3): 2}), 10, atol=1e-9)
|
self.assertAlmostEqual(sum_x2(values), 20)
|
||||||
# x2 a2 x3 == 2 0 2
|
assert lookup_a2_x3.nrFrontals() == 2
|
||||||
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 0, X(3): 2}), 10, atol=1e-9)
|
# And that the non-zero rewards are for x2 a2 x3 == 1 1 2
|
||||||
# x2 a2 x3 == 2 1 2
|
values = DiscreteValues()
|
||||||
assert np.isclose(lookup_a2_x3({X(2): 2, A(2): 1, X(3): 2}), 10, atol=1e-9)
|
values[X(2)] = 1
|
||||||
|
values[A(2)] = 1
|
||||||
|
values[X(3)] = 2
|
||||||
|
self.assertAlmostEqual(lookup_a2_x3(values), 10)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue