multifrontal lookup table tests

release/4.3a0
Frank Dellaert 2023-06-10 12:34:30 -07:00
parent 9857e62a56
commit 0a24a8ac43
4 changed files with 69 additions and 7 deletions

View File

@ -80,7 +80,7 @@ namespace gtsam {
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest. and the last key the fastest.
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
@ -101,7 +101,7 @@ namespace gtsam {
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest. and the last key the fastest.
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);

View File

@ -59,6 +59,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
//** evaluate conditional probability of subtree for given DiscreteValues */
double evaluate(const DiscreteValues& values) const;
//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}
};
/* ************************************************************************* */

View File

@ -215,6 +215,7 @@ class DiscreteBayesTreeClique {
const string& s = "Clique: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
};
class DiscreteBayesTree {
@ -229,6 +230,7 @@ class DiscreteBayesTree {
const DiscreteBayesTreeClique* operator[](size_t j) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

View File

@ -16,6 +16,7 @@
*/
#include <gtsam/base/Vector.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
@ -26,7 +27,6 @@
#include <iostream>
#include <vector>
using namespace std;
using namespace gtsam;
static constexpr bool debug = false;
@ -108,7 +108,7 @@ TEST(DiscreteBayesTree, ThinTree) {
/* ************************************************************************* */
// Check calculation of separator marginals
TEST(DiscreteBayesTree, separatorMarginal) {
TEST(DiscreteBayesTree, SeparatorMarginals) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
@ -141,7 +141,7 @@ TEST(DiscreteBayesTree, separatorMarginal) {
/* ************************************************************************* */
// Check shortcuts in the tree
TEST(DiscreteBayesTree, shortcut) {
TEST(DiscreteBayesTree, Shortcuts) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
@ -199,7 +199,7 @@ TEST(DiscreteBayesTree, shortcut) {
/* ************************************************************************* */
// Check all marginals
TEST(DiscreteBayesTree, marginalFactor) {
TEST(DiscreteBayesTree, MarginalFactors) {
TestFixture self;
Vector marginals = Vector::Zero(15);
@ -286,7 +286,7 @@ TEST(DiscreteBayesTree, Joints) {
/* ************************************************************************* */
TEST(DiscreteBayesTree, Dot) {
TestFixture self;
string actual = self.bayesTree->dot();
std::string actual = self.bayesTree->dot();
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13, 11, 6, 7\"];\n"
@ -313,6 +313,61 @@ TEST(DiscreteBayesTree, Dot) {
"}");
}
/* ************************************************************************* */
// Check that we can have a multi-frontal lookup table
TEST(DiscreteBayesTree, Lookup) {
using gtsam::symbol_shorthand::A;
using gtsam::symbol_shorthand::X;
// Make a small planning-like graph: 3 states, 2 actions
DiscreteFactorGraph graph;
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
const DiscreteKeys keys{x1, x2, x3, a1, a2};
// Constraint on start and goal
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
// Should I stay or should I go?
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
const double r = 10;
std::vector<double> table{
r, 0, 0, 0, r, 0, // x1 = 0
0, r, 0, 0, 0, r, // x1 = 1
0, 0, r, 0, 0, r // x1 = 2
};
graph.add(DiscreteKeys{x1, a1, x2}, table);
graph.add(DiscreteKeys{x2, a2, x3}, table);
// eliminate for MPE (maximum probable explanation).
Ordering ordering{A(2), X(3), X(1), A(1), X(2)};
auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE);
// Check that the lookup table is correct
EXPECT_LONGS_EQUAL(2, lookup->size());
auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional();
EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size());
// check that sum is 100
DiscreteValues empty;
EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9);
// And that only non-zero reward is for x1 a1 x2 == 0 1 1
EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9);
auto lookup_a2_x3 = (*lookup)[X(3)]->conditional();
// check that the sum depends on x2 and is non-zero only for x2 \in {1,2}
auto sum_x2 = lookup_a2_x3->sum(2);
EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9);
EXPECT_DOUBLES_EQUAL(10, (*sum_x2)({{X(2),1}}), 1e-9);
EXPECT_DOUBLES_EQUAL(20, (*sum_x2)({{X(2),2}}), 1e-9);
EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size());
// And that the non-zero rewards are for
// x2 a2 x3 == 1 1 2
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9);
// x2 a2 x3 == 2 0 2
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9);
// x2 a2 x3 == 2 1 2
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
}
/* ************************************************************************* */
int main() {
TestResult tr;