New test for DiscreteMarginals
parent
b2f2db4fdd
commit
22b1002703
|
@ -29,7 +29,7 @@ namespace gtsam {
|
|||
* Assigns to each label a value. Implemented as a simple map.
|
||||
* A discrete factor takes an Assignment and returns a value.
|
||||
*/
|
||||
template <class L>
|
||||
template<class L>
|
||||
class Assignment: public std::map<L, size_t> {
|
||||
public:
|
||||
void print(const std::string& s = "Assignment: ") const {
|
||||
|
@ -42,6 +42,6 @@ namespace gtsam {
|
|||
bool equals(const Assignment& other, double tol = 1e-9) const {
|
||||
return (*this == other);
|
||||
}
|
||||
};
|
||||
}; //Assignment
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -44,8 +44,6 @@ namespace gtsam {
|
|||
typedef JunctionTree<DiscreteFactorGraph> DiscreteJT;
|
||||
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
|
||||
bayesTree_ = *solver.eliminate(&EliminateDiscrete);
|
||||
//bayesTree_.print();
|
||||
cout << "Discrete Marginal Constructor Finished" << endl;
|
||||
}
|
||||
|
||||
/** Compute the marginal of a single variable */
|
||||
|
|
|
@ -51,6 +51,49 @@ TEST_UNSAFE( DiscreteMarginals, UGM_small ) {
|
|||
|
||||
Vector actualCvector = marginals.marginalProbabilities(Cathy);
|
||||
EXPECT(assert_equal(Vector_(2, 0.359631, 0.640369), actualCvector, 1e-6));
|
||||
|
||||
actualCvector = marginals.marginalProbabilities(Mark);
|
||||
EXPECT(assert_equal(Vector_(2, 0.48628, 0.51372), actualCvector, 1e-6));
|
||||
}
|
||||
|
||||
TEST_UNSAFE( DiscreteMarginals, UGM_chain ) {
|
||||
|
||||
const int nrNodes = 10;
|
||||
const size_t nrStates = 7;
|
||||
|
||||
// define variables
|
||||
vector<DiscreteKey> nodes;
|
||||
for (int i = 0; i < nrNodes; i++) {
|
||||
DiscreteKey dk(i, nrStates);
|
||||
nodes.push_back(dk);
|
||||
}
|
||||
|
||||
// create graph
|
||||
DiscreteFactorGraph graph;
|
||||
|
||||
// add node potentials
|
||||
graph.add(nodes[0], ".3 .6 .1 0 0 0 0");
|
||||
for (int i = 1; i < nrNodes; i++)
|
||||
graph.add(nodes[i], "1 1 1 1 1 1 1");
|
||||
|
||||
const std::string edgePotential = ".08 .9 .01 0 0 0 .01 "
|
||||
".03 .95 .01 0 0 0 .01 "
|
||||
".06 .06 .75 .05 .05 .02 .01 "
|
||||
"0 0 0 .3 .6 .09 .01 "
|
||||
"0 0 0 .02 .95 .02 .01 "
|
||||
"0 0 0 .01 .01 .97 .01 "
|
||||
"0 0 0 0 0 0 1";
|
||||
|
||||
// add edge potentials
|
||||
for (int i = 0; i < nrNodes - 1; i++)
|
||||
graph.add(nodes[i] & nodes[i + 1], edgePotential);
|
||||
|
||||
DiscreteMarginals marginals(graph);
|
||||
DiscreteFactor::shared_ptr actualC = marginals(nodes[2].first);
|
||||
DiscreteFactor::Values values;
|
||||
|
||||
values[nodes[2].first] = 0;
|
||||
EXPECT_DOUBLES_EQUAL( 0.03426, (*actualC)(values), 1e-4);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue