Fix confusion between parents and frontals
parent
8666a15f2e
commit
9c5bba753c
|
@ -90,6 +90,22 @@ public:
|
|||
/// GTSAM-style equals
|
||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
/// @name Parent keys are stored *first* in a DiscreteConditional, so re-jigger:
|
||||
/// @{
|
||||
|
||||
/** Iterator pointing to first frontal key. */
|
||||
typename DecisionTreeFactor::const_iterator beginFrontals() const { return endParents(); }
|
||||
|
||||
/** Iterator pointing past the last frontal key. */
|
||||
typename DecisionTreeFactor::const_iterator endFrontals() const { return end(); }
|
||||
|
||||
/** Iterator pointing to the first parent key. */
|
||||
typename DecisionTreeFactor::const_iterator beginParents() const { return begin(); }
|
||||
|
||||
/** Iterator pointing past the last parent key. */
|
||||
typename DecisionTreeFactor::const_iterator endParents() const { return end() - nrFrontals_; }
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/Vector.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
|
@ -29,10 +31,42 @@
|
|||
using namespace boost::assign;
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, bayesNet) {
|
||||
DiscreteBayesNet bayesNet;
|
||||
DiscreteKey Parent(0, 2), Child(1, 2);
|
||||
|
||||
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
|
||||
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
|
||||
(Potentials::ADT)*prior));
|
||||
bayesNet.push_back(prior);
|
||||
|
||||
auto conditional =
|
||||
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
|
||||
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
|
||||
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
|
||||
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
|
||||
bayesNet.push_back(conditional);
|
||||
|
||||
DiscreteFactorGraph fg(bayesNet);
|
||||
LONGS_EQUAL(2, fg.back()->size());
|
||||
|
||||
// Check the marginals
|
||||
const double expectedMarginal[2]{0.4, 0.6 * 0.3 + 0.4 * 0.2};
|
||||
DiscreteMarginals marginals(fg);
|
||||
for (size_t j = 0; j < 2; j++) {
|
||||
Vector FT = marginals.marginalProbabilities(DiscreteKey(j, 2));
|
||||
EXPECT_DOUBLES_EQUAL(expectedMarginal[j], FT[1], 1e-3);
|
||||
EXPECT_DOUBLES_EQUAL(FT[0], 1.0 - FT[1], 1e-9);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesNet, Asia)
|
||||
{
|
||||
|
|
|
@ -36,6 +36,11 @@ TEST( DiscreteConditional, constructors)
|
|||
DiscreteConditional::shared_ptr expected1 = //
|
||||
boost::make_shared<DiscreteConditional>(X | Y = "1/1 2/3 1/4");
|
||||
EXPECT(expected1);
|
||||
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
||||
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
||||
EXPECT(expected1->endParents() == expected1->beginFrontals());
|
||||
EXPECT(expected1->endFrontals() == expected1->end());
|
||||
|
||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||
DiscreteConditional actual1(1, f1);
|
||||
EXPECT(assert_equal(*expected1, actual1, 1e-9));
|
||||
|
|
Loading…
Reference in New Issue