multifrontal lookup table tests
parent
9857e62a56
commit
0a24a8ac43
|
@ -80,7 +80,7 @@ namespace gtsam {
|
|||
* @endcode
|
||||
*
|
||||
* The values in the table should be laid out so that the first key varies
|
||||
* the slowest. and the last key the fastest.
|
||||
* the slowest, and the last key the fastest.
|
||||
*/
|
||||
DecisionTreeFactor(const DiscreteKeys& keys,
|
||||
const std::vector<double>& table);
|
||||
|
@ -101,7 +101,7 @@ namespace gtsam {
|
|||
* @endcode
|
||||
*
|
||||
* The values in the table should be laid out so that the first key varies
|
||||
* the slowest. and the last key the fastest.
|
||||
* the slowest, and the last key the fastest.
|
||||
*/
|
||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
||||
|
||||
|
|
|
@ -59,6 +59,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
|
|||
|
||||
//** evaluate conditional probability of subtree for given DiscreteValues */
|
||||
double evaluate(const DiscreteValues& values) const;
|
||||
|
||||
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||
double operator()(const DiscreteValues& values) const {
|
||||
return evaluate(values);
|
||||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -215,6 +215,7 @@ class DiscreteBayesTreeClique {
|
|||
const string& s = "Clique: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
};
|
||||
|
||||
class DiscreteBayesTree {
|
||||
|
@ -229,6 +230,7 @@ class DiscreteBayesTree {
|
|||
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
||||
|
||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
|
||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/base/Vector.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
|
@ -26,7 +27,6 @@
|
|||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
static constexpr bool debug = false;
|
||||
|
||||
|
@ -108,7 +108,7 @@ TEST(DiscreteBayesTree, ThinTree) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of separator marginals
|
||||
TEST(DiscreteBayesTree, separatorMarginal) {
|
||||
TEST(DiscreteBayesTree, SeparatorMarginals) {
|
||||
TestFixture self;
|
||||
|
||||
// Calculate some marginals for DiscreteValues==all1
|
||||
|
@ -141,7 +141,7 @@ TEST(DiscreteBayesTree, separatorMarginal) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check shortcuts in the tree
|
||||
TEST(DiscreteBayesTree, shortcut) {
|
||||
TEST(DiscreteBayesTree, Shortcuts) {
|
||||
TestFixture self;
|
||||
|
||||
// Calculate some marginals for DiscreteValues==all1
|
||||
|
@ -199,7 +199,7 @@ TEST(DiscreteBayesTree, shortcut) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check all marginals
|
||||
TEST(DiscreteBayesTree, marginalFactor) {
|
||||
TEST(DiscreteBayesTree, MarginalFactors) {
|
||||
TestFixture self;
|
||||
|
||||
Vector marginals = Vector::Zero(15);
|
||||
|
@ -286,7 +286,7 @@ TEST(DiscreteBayesTree, Joints) {
|
|||
/* ************************************************************************* */
|
||||
TEST(DiscreteBayesTree, Dot) {
|
||||
TestFixture self;
|
||||
string actual = self.bayesTree->dot();
|
||||
std::string actual = self.bayesTree->dot();
|
||||
EXPECT(actual ==
|
||||
"digraph G{\n"
|
||||
"0[label=\"13, 11, 6, 7\"];\n"
|
||||
|
@ -313,6 +313,61 @@ TEST(DiscreteBayesTree, Dot) {
|
|||
"}");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check that we can have a multi-frontal lookup table
|
||||
TEST(DiscreteBayesTree, Lookup) {
|
||||
using gtsam::symbol_shorthand::A;
|
||||
using gtsam::symbol_shorthand::X;
|
||||
|
||||
// Make a small planning-like graph: 3 states, 2 actions
|
||||
DiscreteFactorGraph graph;
|
||||
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
|
||||
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
|
||||
const DiscreteKeys keys{x1, x2, x3, a1, a2};
|
||||
// Constraint on start and goal
|
||||
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
|
||||
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
|
||||
// Should I stay or should I go?
|
||||
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
||||
const double r = 10;
|
||||
std::vector<double> table{
|
||||
r, 0, 0, 0, r, 0, // x1 = 0
|
||||
0, r, 0, 0, 0, r, // x1 = 1
|
||||
0, 0, r, 0, 0, r // x1 = 2
|
||||
};
|
||||
graph.add(DiscreteKeys{x1, a1, x2}, table);
|
||||
graph.add(DiscreteKeys{x2, a2, x3}, table);
|
||||
|
||||
// eliminate for MPE (maximum probable explanation).
|
||||
Ordering ordering{A(2), X(3), X(1), A(1), X(2)};
|
||||
auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE);
|
||||
|
||||
// Check that the lookup table is correct
|
||||
EXPECT_LONGS_EQUAL(2, lookup->size());
|
||||
auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional();
|
||||
EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size());
|
||||
// check that sum is 100
|
||||
DiscreteValues empty;
|
||||
EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9);
|
||||
// And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
||||
EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9);
|
||||
|
||||
auto lookup_a2_x3 = (*lookup)[X(3)]->conditional();
|
||||
// check that the sum depends on x2 and is non-zero only for x2 \in {1,2}
|
||||
auto sum_x2 = lookup_a2_x3->sum(2);
|
||||
EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(10, (*sum_x2)({{X(2),1}}), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(20, (*sum_x2)({{X(2),2}}), 1e-9);
|
||||
EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size());
|
||||
// And that the non-zero rewards are for
|
||||
// x2 a2 x3 == 1 1 2
|
||||
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9);
|
||||
// x2 a2 x3 == 2 0 2
|
||||
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9);
|
||||
// x2 a2 x3 == 2 1 2
|
||||
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue