diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 225e6e1d3..b1e9da754 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -90,6 +90,22 @@ public: /// GTSAM-style equals bool equals(const DiscreteFactor& other, double tol = 1e-9) const; + /// @} + /// @name Parent keys are stored *first* in a DiscreteConditional, so re-jigger: + /// @{ + + /** Iterator pointing to first frontal key. */ + typename DecisionTreeFactor::const_iterator beginFrontals() const { return endParents(); } + + /** Iterator pointing past the last frontal key. */ + typename DecisionTreeFactor::const_iterator endFrontals() const { return end(); } + + /** Iterator pointing to the first parent key. */ + typename DecisionTreeFactor::const_iterator beginParents() const { return begin(); } + + /** Iterator pointing past the last parent key. */ + typename DecisionTreeFactor::const_iterator endParents() const { return end() - nrFrontals_; } + /// @} /// @name Standard Interface /// @{ diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 5ed662332..c3f8aacf1 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -18,8 +18,10 @@ #include #include -#include +#include #include +#include +#include #include @@ -29,10 +31,42 @@ using namespace boost::assign; #include +#include +#include using namespace std; using namespace gtsam; +/* ************************************************************************* */ +TEST(DiscreteBayesNet, bayesNet) { + DiscreteBayesNet bayesNet; + DiscreteKey Parent(0, 2), Child(1, 2); + + auto prior = boost::make_shared(Parent % "6/4"); + CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"), + (Potentials::ADT)*prior)); + bayesNet.push_back(prior); + + auto conditional = + boost::make_shared(Child | Parent = "7/3 8/2"); + EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals())); + Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2"); + CHECK(assert_equal(expected, (Potentials::ADT)*conditional)); + bayesNet.push_back(conditional); + + DiscreteFactorGraph fg(bayesNet); + LONGS_EQUAL(2, fg.back()->size()); + + // Check the marginals + const double expectedMarginal[2]{0.4, 0.6 * 0.3 + 0.4 * 0.2}; + DiscreteMarginals marginals(fg); + for (size_t j = 0; j < 2; j++) { + Vector FT = marginals.marginalProbabilities(DiscreteKey(j, 2)); + EXPECT_DOUBLES_EQUAL(expectedMarginal[j], FT[1], 1e-3); + EXPECT_DOUBLES_EQUAL(FT[0], 1.0 - FT[1], 1e-9); + } +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 888bf76df..577edecb3 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -36,6 +36,11 @@ TEST( DiscreteConditional, constructors) DiscreteConditional::shared_ptr expected1 = // boost::make_shared(X | Y = "1/1 2/3 1/4"); EXPECT(expected1); + EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); + EXPECT(expected1->endParents() == expected1->beginFrontals()); + EXPECT(expected1->endFrontals() == expected1->end()); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional actual1(1, f1); EXPECT(assert_equal(*expected1, actual1, 1e-9));