Fix confusion between parents and frontals

release/4.3a0
Frank dellaert 2020-07-12 12:28:43 -04:00
parent 8666a15f2e
commit 9c5bba753c
3 changed files with 56 additions and 1 deletions

View File

@ -90,6 +90,22 @@ public:
/// GTSAM-style equals
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
/// @}
/// @name Parent keys are stored *first* in a DiscreteConditional, so re-jigger:
/// @{
/** Iterator pointing to first frontal key. */
typename DecisionTreeFactor::const_iterator beginFrontals() const { return endParents(); }
/** Iterator pointing past the last frontal key. */
typename DecisionTreeFactor::const_iterator endFrontals() const { return end(); }
/** Iterator pointing to the first parent key. */
typename DecisionTreeFactor::const_iterator beginParents() const { return begin(); }
/** Iterator pointing past the last parent key. */
typename DecisionTreeFactor::const_iterator endParents() const { return end() - nrFrontals_; }
/// @}
/// @name Standard Interface
/// @{

View File

@ -18,8 +18,10 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/base/debug.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Vector.h>
#include <CppUnitLite/TestHarness.h>
@ -29,10 +31,42 @@
using namespace boost::assign;
#include <iostream>
#include <string>
#include <vector>
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet;
DiscreteKey Parent(0, 2), Child(1, 2);
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
(Potentials::ADT)*prior));
bayesNet.push_back(prior);
auto conditional =
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
bayesNet.push_back(conditional);
DiscreteFactorGraph fg(bayesNet);
LONGS_EQUAL(2, fg.back()->size());
// Check the marginals
const double expectedMarginal[2]{0.4, 0.6 * 0.3 + 0.4 * 0.2};
DiscreteMarginals marginals(fg);
for (size_t j = 0; j < 2; j++) {
Vector FT = marginals.marginalProbabilities(DiscreteKey(j, 2));
EXPECT_DOUBLES_EQUAL(expectedMarginal[j], FT[1], 1e-3);
EXPECT_DOUBLES_EQUAL(FT[0], 1.0 - FT[1], 1e-9);
}
}
/* ************************************************************************* */
TEST(DiscreteBayesNet, Asia)
{

View File

@ -36,6 +36,11 @@ TEST( DiscreteConditional, constructors)
DiscreteConditional::shared_ptr expected1 = //
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
EXPECT(expected1);
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
EXPECT(expected1->endParents() == expected1->beginFrontals());
EXPECT(expected1->endFrontals() == expected1->end());
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional actual1(1, f1);
EXPECT(assert_equal(*expected1, actual1, 1e-9));