New test for DiscreteMarginals

release/4.3a0
Abhijit Kundu 2012-06-07 22:14:03 +00:00
parent b2f2db4fdd
commit 22b1002703
3 changed files with 45 additions and 4 deletions

View File

@ -29,7 +29,7 @@ namespace gtsam {
* Assigns to each label a value. Implemented as a simple map. * Assigns to each label a value. Implemented as a simple map.
* A discrete factor takes an Assignment and returns a value. * A discrete factor takes an Assignment and returns a value.
*/ */
template <class L> template<class L>
class Assignment: public std::map<L, size_t> { class Assignment: public std::map<L, size_t> {
public: public:
void print(const std::string& s = "Assignment: ") const { void print(const std::string& s = "Assignment: ") const {
@ -42,6 +42,6 @@ namespace gtsam {
bool equals(const Assignment& other, double tol = 1e-9) const { bool equals(const Assignment& other, double tol = 1e-9) const {
return (*this == other); return (*this == other);
} }
}; }; //Assignment
} // namespace gtsam } // namespace gtsam

View File

@ -44,8 +44,6 @@ namespace gtsam {
typedef JunctionTree<DiscreteFactorGraph> DiscreteJT; typedef JunctionTree<DiscreteFactorGraph> DiscreteJT;
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph); GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
bayesTree_ = *solver.eliminate(&EliminateDiscrete); bayesTree_ = *solver.eliminate(&EliminateDiscrete);
//bayesTree_.print();
cout << "Discrete Marginal Constructor Finished" << endl;
} }
/** Compute the marginal of a single variable */ /** Compute the marginal of a single variable */

View File

@ -51,6 +51,49 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) {
Vector actualCvector = marginals.marginalProbabilities(Cathy); Vector actualCvector = marginals.marginalProbabilities(Cathy);
EXPECT(assert_equal(Vector_(2, 0.359631, 0.640369), actualCvector, 1e-6)); EXPECT(assert_equal(Vector_(2, 0.359631, 0.640369), actualCvector, 1e-6));
actualCvector = marginals.marginalProbabilities(Mark);
EXPECT(assert_equal(Vector_(2, 0.48628, 0.51372), actualCvector, 1e-6));
}
TEST_UNSAFE( DiscreteMarginals, UGM_chain ) {
const int nrNodes = 10;
const size_t nrStates = 7;
// define variables
vector<DiscreteKey> nodes;
for (int i = 0; i < nrNodes; i++) {
DiscreteKey dk(i, nrStates);
nodes.push_back(dk);
}
// create graph
DiscreteFactorGraph graph;
// add node potentials
graph.add(nodes[0], ".3 .6 .1 0 0 0 0");
for (int i = 1; i < nrNodes; i++)
graph.add(nodes[i], "1 1 1 1 1 1 1");
const std::string edgePotential = ".08 .9 .01 0 0 0 .01 "
".03 .95 .01 0 0 0 .01 "
".06 .06 .75 .05 .05 .02 .01 "
"0 0 0 .3 .6 .09 .01 "
"0 0 0 .02 .95 .02 .01 "
"0 0 0 .01 .01 .97 .01 "
"0 0 0 0 0 0 1";
// add edge potentials
for (int i = 0; i < nrNodes - 1; i++)
graph.add(nodes[i] & nodes[i + 1], edgePotential);
DiscreteMarginals marginals(graph);
DiscreteFactor::shared_ptr actualC = marginals(nodes[2].first);
DiscreteFactor::Values values;
values[nodes[2].first] = 0;
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
} }
/* ************************************************************************* */ /* ************************************************************************* */