diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h index 06e2ac137..2e84485dd 100644 --- a/gtsam/discrete/Assignment.h +++ b/gtsam/discrete/Assignment.h @@ -29,7 +29,7 @@ namespace gtsam { * Assigns to each label a value. Implemented as a simple map. * A discrete factor takes an Assignment and returns a value. */ - template + template class Assignment: public std::map { public: void print(const std::string& s = "Assignment: ") const { @@ -42,6 +42,6 @@ namespace gtsam { bool equals(const Assignment& other, double tol = 1e-9) const { return (*this == other); } - }; + }; //Assignment } // namespace gtsam diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index 99176e58c..fdc3398da 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -44,8 +44,6 @@ namespace gtsam { typedef JunctionTree DiscreteJT; GenericMultifrontalSolver solver(graph); bayesTree_ = *solver.eliminate(&EliminateDiscrete); - //bayesTree_.print(); - cout << "Discrete Marginal Constructor Finished" << endl; } /** Compute the marginal of a single variable */ diff --git a/gtsam/discrete/tests/testDiscreteMarginals.cpp b/gtsam/discrete/tests/testDiscreteMarginals.cpp index 1b5963864..ad35e131e 100644 --- a/gtsam/discrete/tests/testDiscreteMarginals.cpp +++ b/gtsam/discrete/tests/testDiscreteMarginals.cpp @@ -51,6 +51,49 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) { Vector actualCvector = marginals.marginalProbabilities(Cathy); EXPECT(assert_equal(Vector_(2, 0.359631, 0.640369), actualCvector, 1e-6)); + + actualCvector = marginals.marginalProbabilities(Mark); + EXPECT(assert_equal(Vector_(2, 0.48628, 0.51372), actualCvector, 1e-6)); +} + +TEST_UNSAFE( DiscreteMarginals, UGM_chain ) { + + const int nrNodes = 10; + const size_t nrStates = 7; + + // define variables + vector nodes; + for (int i = 0; i < nrNodes; i++) { + DiscreteKey dk(i, nrStates); + nodes.push_back(dk); + } + + // create graph + DiscreteFactorGraph graph; + + // add node potentials + graph.add(nodes[0], ".3 .6 .1 0 0 0 0"); + for (int i = 1; i < nrNodes; i++) + graph.add(nodes[i], "1 1 1 1 1 1 1"); + + const std::string edgePotential = ".08 .9 .01 0 0 0 .01 " + ".03 .95 .01 0 0 0 .01 " + ".06 .06 .75 .05 .05 .02 .01 " + "0 0 0 .3 .6 .09 .01 " + "0 0 0 .02 .95 .02 .01 " + "0 0 0 .01 .01 .97 .01 " + "0 0 0 0 0 0 1"; + + // add edge potentials + for (int i = 0; i < nrNodes - 1; i++) + graph.add(nodes[i] & nodes[i + 1], edgePotential); + + DiscreteMarginals marginals(graph); + DiscreteFactor::shared_ptr actualC = marginals(nodes[2].first); + DiscreteFactor::Values values; + + values[nodes[2].first] = 0; + EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4); } /* ************************************************************************* */