New test for DiscreteMarginals
parent
b2f2db4fdd
commit
22b1002703
|
@ -29,7 +29,7 @@ namespace gtsam {
|
||||||
* Assigns to each label a value. Implemented as a simple map.
|
* Assigns to each label a value. Implemented as a simple map.
|
||||||
* A discrete factor takes an Assignment and returns a value.
|
* A discrete factor takes an Assignment and returns a value.
|
||||||
*/
|
*/
|
||||||
template <class L>
|
template<class L>
|
||||||
class Assignment: public std::map<L, size_t> {
|
class Assignment: public std::map<L, size_t> {
|
||||||
public:
|
public:
|
||||||
void print(const std::string& s = "Assignment: ") const {
|
void print(const std::string& s = "Assignment: ") const {
|
||||||
|
@ -42,6 +42,6 @@ namespace gtsam {
|
||||||
bool equals(const Assignment& other, double tol = 1e-9) const {
|
bool equals(const Assignment& other, double tol = 1e-9) const {
|
||||||
return (*this == other);
|
return (*this == other);
|
||||||
}
|
}
|
||||||
};
|
}; //Assignment
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -44,8 +44,6 @@ namespace gtsam {
|
||||||
typedef JunctionTree<DiscreteFactorGraph> DiscreteJT;
|
typedef JunctionTree<DiscreteFactorGraph> DiscreteJT;
|
||||||
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
|
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
|
||||||
bayesTree_ = *solver.eliminate(&EliminateDiscrete);
|
bayesTree_ = *solver.eliminate(&EliminateDiscrete);
|
||||||
//bayesTree_.print();
|
|
||||||
cout << "Discrete Marginal Constructor Finished" << endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Compute the marginal of a single variable */
|
/** Compute the marginal of a single variable */
|
||||||
|
|
|
@ -51,6 +51,49 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) {
|
||||||
|
|
||||||
Vector actualCvector = marginals.marginalProbabilities(Cathy);
|
Vector actualCvector = marginals.marginalProbabilities(Cathy);
|
||||||
EXPECT(assert_equal(Vector_(2, 0.359631, 0.640369), actualCvector, 1e-6));
|
EXPECT(assert_equal(Vector_(2, 0.359631, 0.640369), actualCvector, 1e-6));
|
||||||
|
|
||||||
|
actualCvector = marginals.marginalProbabilities(Mark);
|
||||||
|
EXPECT(assert_equal(Vector_(2, 0.48628, 0.51372), actualCvector, 1e-6));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_UNSAFE( DiscreteMarginals, UGM_chain ) {
|
||||||
|
|
||||||
|
const int nrNodes = 10;
|
||||||
|
const size_t nrStates = 7;
|
||||||
|
|
||||||
|
// define variables
|
||||||
|
vector<DiscreteKey> nodes;
|
||||||
|
for (int i = 0; i < nrNodes; i++) {
|
||||||
|
DiscreteKey dk(i, nrStates);
|
||||||
|
nodes.push_back(dk);
|
||||||
|
}
|
||||||
|
|
||||||
|
// create graph
|
||||||
|
DiscreteFactorGraph graph;
|
||||||
|
|
||||||
|
// add node potentials
|
||||||
|
graph.add(nodes[0], ".3 .6 .1 0 0 0 0");
|
||||||
|
for (int i = 1; i < nrNodes; i++)
|
||||||
|
graph.add(nodes[i], "1 1 1 1 1 1 1");
|
||||||
|
|
||||||
|
const std::string edgePotential = ".08 .9 .01 0 0 0 .01 "
|
||||||
|
".03 .95 .01 0 0 0 .01 "
|
||||||
|
".06 .06 .75 .05 .05 .02 .01 "
|
||||||
|
"0 0 0 .3 .6 .09 .01 "
|
||||||
|
"0 0 0 .02 .95 .02 .01 "
|
||||||
|
"0 0 0 .01 .01 .97 .01 "
|
||||||
|
"0 0 0 0 0 0 1";
|
||||||
|
|
||||||
|
// add edge potentials
|
||||||
|
for (int i = 0; i < nrNodes - 1; i++)
|
||||||
|
graph.add(nodes[i] & nodes[i + 1], edgePotential);
|
||||||
|
|
||||||
|
DiscreteMarginals marginals(graph);
|
||||||
|
DiscreteFactor::shared_ptr actualC = marginals(nodes[2].first);
|
||||||
|
DiscreteFactor::Values values;
|
||||||
|
|
||||||
|
values[nodes[2].first] = 0;
|
||||||
|
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue