Fix confusion between parents and frontals
parent
8666a15f2e
commit
9c5bba753c
|
@ -90,6 +90,22 @@ public:
|
||||||
/// GTSAM-style equals
|
/// GTSAM-style equals
|
||||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
|
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Parent keys are stored *first* in a DiscreteConditional, so re-jigger:
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Iterator pointing to first frontal key. */
|
||||||
|
typename DecisionTreeFactor::const_iterator beginFrontals() const { return endParents(); }
|
||||||
|
|
||||||
|
/** Iterator pointing past the last frontal key. */
|
||||||
|
typename DecisionTreeFactor::const_iterator endFrontals() const { return end(); }
|
||||||
|
|
||||||
|
/** Iterator pointing to the first parent key. */
|
||||||
|
typename DecisionTreeFactor::const_iterator beginParents() const { return begin(); }
|
||||||
|
|
||||||
|
/** Iterator pointing past the last parent key. */
|
||||||
|
typename DecisionTreeFactor::const_iterator endParents() const { return end() - nrFrontals_; }
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -18,8 +18,10 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||||
#include <gtsam/base/debug.h>
|
#include <gtsam/base/debug.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/Vector.h>
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
@ -29,10 +31,42 @@
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteBayesNet, bayesNet) {
|
||||||
|
DiscreteBayesNet bayesNet;
|
||||||
|
DiscreteKey Parent(0, 2), Child(1, 2);
|
||||||
|
|
||||||
|
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
|
||||||
|
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
|
||||||
|
(Potentials::ADT)*prior));
|
||||||
|
bayesNet.push_back(prior);
|
||||||
|
|
||||||
|
auto conditional =
|
||||||
|
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
|
||||||
|
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
|
||||||
|
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
|
||||||
|
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
|
||||||
|
bayesNet.push_back(conditional);
|
||||||
|
|
||||||
|
DiscreteFactorGraph fg(bayesNet);
|
||||||
|
LONGS_EQUAL(2, fg.back()->size());
|
||||||
|
|
||||||
|
// Check the marginals
|
||||||
|
const double expectedMarginal[2]{0.4, 0.6 * 0.3 + 0.4 * 0.2};
|
||||||
|
DiscreteMarginals marginals(fg);
|
||||||
|
for (size_t j = 0; j < 2; j++) {
|
||||||
|
Vector FT = marginals.marginalProbabilities(DiscreteKey(j, 2));
|
||||||
|
EXPECT_DOUBLES_EQUAL(expectedMarginal[j], FT[1], 1e-3);
|
||||||
|
EXPECT_DOUBLES_EQUAL(FT[0], 1.0 - FT[1], 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesNet, Asia)
|
TEST(DiscreteBayesNet, Asia)
|
||||||
{
|
{
|
||||||
|
|
|
@ -36,6 +36,11 @@ TEST( DiscreteConditional, constructors)
|
||||||
DiscreteConditional::shared_ptr expected1 = //
|
DiscreteConditional::shared_ptr expected1 = //
|
||||||
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
||||||
EXPECT(expected1);
|
EXPECT(expected1);
|
||||||
|
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
||||||
|
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
||||||
|
EXPECT(expected1->endParents() == expected1->beginFrontals());
|
||||||
|
EXPECT(expected1->endFrontals() == expected1->end());
|
||||||
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional actual1(1, f1);
|
DiscreteConditional actual1(1, f1);
|
||||||
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
||||||
|
|
Loading…
Reference in New Issue