Merge pull request #1619 from borglab/release/4.2

release/4.3a0
Frank Dellaert 2023-09-03 19:21:18 -07:00 committed by GitHub
commit 4f66a491ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 922 additions and 226 deletions

View File

@ -28,9 +28,9 @@
namespace gtsam {
/**
* Algebraic Decision Trees fix the range to double
* Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether?
* An algebraic decision tree fixes the range of a DecisionTree to double.
* Just has some nice constructors and some syntactic sugar.
* TODO(dellaert): consider eliminating this class altogether?
*
* @ingroup discrete
*/
@ -80,20 +80,62 @@ namespace gtsam {
AlgebraicDecisionTree(const L& label, double y1, double y2)
: Base(label, y1, y2) {}
/** Create a new leaf function splitting on a variable */
/**
* @brief Create a new leaf function splitting on a variable
*
* @param labelC: The label with cardinality 2
* @param y1: The value for the first key
* @param y2: The value for the second key
*
* Example:
* @code{.cpp}
* std::pair<string, size_t> A {"a", 2};
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
* @endcode
*/
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
double y2)
: Base(labelC, y1, y2) {}
/** Create from keys and vector table */
/**
* @brief Create from keys with cardinalities and a vector table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param ys: The vector table
*
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
* respectively, and a vector table of size 12:
* @code{.cpp}
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
* const vector<double> cpt{
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
* @endcode
* The table is given in the following order:
* A=0, B=0, C=0
* A=0, B=0, C=1
* ...
* A=1, B=1, C=1
* Hence, the first line in the table is for A==0, and the second for A==1.
* In each line, the first two entries are for B==0, the next two for B==1,
* and the last two for B==2. Each pair is for a C value of 0 and 1.
*/
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::vector<double>& ys) {
const std::vector<double>& ys) {
this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/** Create from keys and string table */
/**
* @brief Create from keys and string table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param table: The string table, given as a string of doubles.
*
* @note Table needs to be in same order as the vector table in the other constructor.
*/
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
@ -108,7 +150,13 @@ namespace gtsam {
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}
/** Create a new function splitting on a variable */
/**
* @brief Create a range of decision trees, splitting on a single variable.
*
* @param begin: Iterator to beginning of a range of decision trees
* @param end: Iterator to end of a range of decision trees
* @param label: The label to split on
*/
template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
: Base(nullptr) {

View File

@ -622,7 +622,7 @@ namespace gtsam {
// B=1
// A=0: 3
// A=1: 4
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
// Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
// exactly the same tree as above: the highest label is always the root.
// However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y>

View File

@ -37,9 +37,23 @@
namespace gtsam {
/**
* Decision Tree
* L = label for variables
* Y = function range (any algebra), e.g., bool, int, double
* @brief a decision tree is a function from assignments to values.
* @tparam L label for variables
* @tparam Y function range (any algebra), e.g., bool, int, double
*
* After creating a decision tree on some variables, the tree can be evaluated
* on an assignment to those variables. Example:
*
* @code{.cpp}
* // Create a decision stump one one variable 'a' with values 10 and 20.
* DecisionTree<char, int> tree('a', 10, 20);
*
* // Evaluate the tree on an assignment to the variable.
* int value0 = tree({{'a', 0}}); // value0 = 10
* int value1 = tree({{'a', 1}}); // value1 = 20
* @endcode
*
* More examples can be found in testDecisionTree.cpp
*
* @ingroup discrete
*/
@ -132,7 +146,8 @@ namespace gtsam {
NodePtr root_;
protected:
/** Internal recursive function to create from keys, cardinalities,
/**
* Internal recursive function to create from keys, cardinalities,
* and Y values
*/
template<typename It, typename ValueIt>
@ -163,7 +178,13 @@ namespace gtsam {
/** Create a constant */
explicit DecisionTree(const Y& y);
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
/**
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
*
* @param label The variable to split on.
* @param y1 The value for the first assignment.
* @param y2 The value for the second assignment.
*/
DecisionTree(const L& label, const Y& y1, const Y& y2);
/** Allow Label+Cardinality for convenience */

View File

@ -63,11 +63,46 @@ namespace gtsam {
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */
/**
* @brief Constructor from doubles
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* const std::vector<double> table {2, 5, 3, 6, 4, 7};
* DecisionTreeFactor f1({X, Y}, table);
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
const std::vector<double>& table);
/** Constructor from string */
/**
* @brief Constructor from string
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7");
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
/// Single-key specialization

View File

@ -59,6 +59,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
//** evaluate conditional probability of subtree for given DiscreteValues */
double evaluate(const DiscreteValues& values) const;
//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}
};
/* ************************************************************************* */

View File

@ -42,16 +42,30 @@ class DiscreteJunctionTree;
/**
* @brief Main elimination function for DiscreteFactorGraph.
*
* @param factors
* @param keys
* @return GTSAM_EXPORT
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting conditional and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/**
* @brief Alternate elimination function for that creates non-normalized lookup tables.
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting lookup table and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/* ************************************************************************* */
template<> struct EliminationTraits<DiscreteFactorGraph>
{
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
@ -61,12 +75,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys);
}
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
@ -75,7 +91,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
}
};
/* ************************************************************************* */
/**
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Implicit copy/downcast constructor to override explicit template container
* constructor */
template <class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
template <class DERIVED_FACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
/// Destructor
virtual ~DiscreteFactorGraph() {}
@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @}
}; // \ DiscreteFactorGraph
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/// traits
template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};

View File

@ -66,4 +66,6 @@ namespace gtsam {
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
};
/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
}

View File

@ -120,6 +120,11 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
/// @}
};
/// Free version of CartesianProduct.
inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
return DiscreteValues::CartesianProduct(keys);
}
/// Free version of markdown.
std::string markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,

View File

@ -17,6 +17,8 @@ class DiscreteKeys {
};
// DiscreteValues is added in specializations/discrete.h as a std::map
std::vector<gtsam::DiscreteValues> cartesianProduct(
const gtsam::DiscreteKeys& keys);
string markdown(
const gtsam::DiscreteValues& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
@ -31,27 +33,30 @@ string html(const gtsam::DiscreteValues& values,
std::map<gtsam::Key, std::vector<std::string>> names);
#include <gtsam/discrete/DiscreteFactor.h>
class DiscreteFactor {
virtual class DiscreteFactor : gtsam::Factor {
void print(string s = "DiscreteFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
double operator()(const gtsam::DiscreteValues& values) const;
};
#include <gtsam/discrete/DecisionTreeFactor.h>
virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor();
DecisionTreeFactor(const gtsam::DiscreteKey& key,
const std::vector<double>& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
DecisionTreeFactor(const gtsam::DiscreteKeys& keys,
const std::vector<double>& table);
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys,
const std::vector<double>& table);
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
void print(string s = "DecisionTreeFactor\n",
@ -59,6 +64,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
size_t cardinality(gtsam::Key j) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const;
@ -66,6 +73,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
gtsam::DecisionTreeFactor* max(const gtsam::Ordering& keys) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
@ -203,10 +211,16 @@ class DiscreteBayesTreeClique {
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
const gtsam::DiscreteConditional* conditional() const;
bool isRoot() const;
size_t nrChildren() const;
const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const;
void print(string s = "DiscreteBayesTreeClique",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printSignature(
const string& s = "Clique: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
};
class DiscreteBayesTree {
@ -220,6 +234,9 @@ class DiscreteBayesTree {
bool empty() const;
const DiscreteBayesTreeClique* operator[](size_t j) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s,
@ -242,9 +259,9 @@ class DiscreteBayesTree {
class DiscreteLookupTable : gtsam::DiscreteConditional{
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
const gtsam::DecisionTreeFactor::ADT& potentials);
void print(
const std::string& s = "Discrete Lookup Table: ",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const;
void print(string s = "Discrete Lookup Table: ",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
};
@ -263,6 +280,14 @@ class DiscreteLookupDAG {
};
#include <gtsam/discrete/DiscreteFactorGraph.h>
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateForMPE(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
class DiscreteFactorGraph {
DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
@ -277,6 +302,7 @@ class DiscreteFactorGraph {
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec);
bool empty() const;
size_t size() const;
@ -290,25 +316,46 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet sumProduct();
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesNet sumProduct(
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
gtsam::DiscreteLookupDAG maxProduct(
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential();
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(const gtsam::Ordering& ordering);
eliminatePartialSequential(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal();
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
@ -328,4 +375,41 @@ class DiscreteFactorGraph {
std::map<gtsam::Key, std::vector<std::string>> names) const;
};
#include <gtsam/discrete/DiscreteEliminationTree.h>
class DiscreteEliminationTree {
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/discrete/DiscreteJunctionTree.h>
class DiscreteCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::DiscreteFactorGraph factors;
const gtsam::DiscreteCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class DiscreteJunctionTree {
DiscreteJunctionTree(const gtsam::DiscreteEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::DiscreteCluster& operator[](size_t i) const;
};
} // namespace gtsam

View File

@ -71,6 +71,19 @@ struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ************************************************************************** */
// Test char labels and int range
/* ************************************************************************** */
// Create a decision stump one one variable 'a' with values 10 and 20.
TEST(DecisionTree, Constructor) {
DecisionTree<char, int> tree('a', 10, 20);
// Evaluate the tree on an assignment to the variable.
EXPECT_LONGS_EQUAL(10, tree({{'a', 0}}));
EXPECT_LONGS_EQUAL(20, tree({{'a', 1}}));
}
/* ************************************************************************** */
// Test string labels and int range
/* ************************************************************************** */
@ -114,18 +127,47 @@ struct Ring {
static inline int mul(const int& a, const int& b) { return a * b; }
};
/* ************************************************************************** */
// Check that creating decision trees respects key order.
TEST(DecisionTree, ConstructorOrder) {
// Create labels
string A("A"), B("B");
const std::vector<int> ys1 = {1, 2, 3, 4};
DT tree1({{B, 2}, {A, 2}}, ys1); // faster version, as B is "higher" than A!
const std::vector<int> ys2 = {1, 3, 2, 4};
DT tree2({{A, 2}, {B, 2}}, ys2); // slower version !
// Both trees will be the same, tree is order from high to low labels.
// Choice(B)
// 0 Choice(A)
// 0 0 Leaf 1
// 0 1 Leaf 2
// 1 Choice(A)
// 1 0 Leaf 3
// 1 1 Leaf 4
EXPECT(tree2.equals(tree1));
// Check the values are as expected by calling the () operator:
EXPECT_LONGS_EQUAL(1, tree1({{A, 0}, {B, 0}}));
EXPECT_LONGS_EQUAL(3, tree1({{A, 0}, {B, 1}}));
EXPECT_LONGS_EQUAL(2, tree1({{A, 1}, {B, 0}}));
EXPECT_LONGS_EQUAL(4, tree1({{A, 1}, {B, 1}}));
}
/* ************************************************************************** */
// test DT
TEST(DecisionTree, example) {
TEST(DecisionTree, Example) {
// Create labels
string A("A"), B("B"), C("C");
// create a value
Assignment<string> x00, x01, x10, x11;
x00[A] = 0, x00[B] = 0;
x01[A] = 0, x01[B] = 1;
x10[A] = 1, x10[B] = 0;
x11[A] = 1, x11[B] = 1;
// Create assignments using brace initialization:
Assignment<string> x00{{A, 0}, {B, 0}};
Assignment<string> x01{{A, 0}, {B, 1}};
Assignment<string> x10{{A, 1}, {B, 0}};
Assignment<string> x11{{A, 1}, {B, 1}};
// empty
DT empty;
@ -237,8 +279,7 @@ TEST(DecisionTree, ConvertValuesOnly) {
StringBoolTree f2(f1, bool_of_int);
// Check a value
Assignment<string> x00;
x00["A"] = 0, x00["B"] = 0;
Assignment<string> x00 {{A, 0}, {B, 0}};
EXPECT(!f2(x00));
}
@ -262,10 +303,11 @@ TEST(DecisionTree, ConvertBoth) {
// Check some values
Assignment<Label> x00, x01, x10, x11;
x00[X] = 0, x00[Y] = 0;
x01[X] = 0, x01[Y] = 1;
x10[X] = 1, x10[Y] = 0;
x11[X] = 1, x11[Y] = 1;
x00 = {{X, 0}, {Y, 0}};
x01 = {{X, 0}, {Y, 1}};
x10 = {{X, 1}, {Y, 0}};
x11 = {{X, 1}, {Y, 1}};
EXPECT(!f2(x00));
EXPECT(!f2(x01));
EXPECT(f2(x10));

View File

@ -27,6 +27,18 @@
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST(DecisionTreeFactor, ConstructorsMatch) {
// Declare two keys
DiscreteKey X(0, 2), Y(1, 3);
// Create with vector and with string
const std::vector<double> table {2, 5, 3, 6, 4, 7};
DecisionTreeFactor f1({X, Y}, table);
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
EXPECT(assert_equal(f1, f2));
}
/* ************************************************************************* */
TEST( DecisionTreeFactor, constructors)
{
@ -41,16 +53,13 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_LONGS_EQUAL(2,f2.size());
EXPECT_LONGS_EQUAL(3,f3.size());
DiscreteValues values;
values[0] = 1; // x
values[1] = 2; // y
values[2] = 1; // z
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9);
// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);
}
/* ************************************************************************* */

View File

@ -16,23 +16,24 @@
*/
#include <gtsam/base/Vector.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesNet.h>
#include <CppUnitLite/TestHarness.h>
#include <iostream>
#include <vector>
using namespace std;
using namespace gtsam;
static constexpr bool debug = false;
/* ************************************************************************* */
struct TestFixture {
vector<DiscreteKey> keys;
DiscreteKeys keys;
std::vector<DiscreteValues> assignments;
DiscreteBayesNet bayesNet;
boost::shared_ptr<DiscreteBayesTree> bayesTree;
@ -47,6 +48,9 @@ struct TestFixture {
keys.push_back(key_i);
}
// Enumerate all assignments.
assignments = DiscreteValues::CartesianProduct(keys);
// Create thin-tree Bayesnet.
bayesNet.add(keys[14] % "1/3");
@ -74,9 +78,9 @@ struct TestFixture {
};
/* ************************************************************************* */
// Check that BN and BT give the same answer on all configurations
TEST(DiscreteBayesTree, ThinTree) {
const TestFixture self;
const auto& keys = self.keys;
TestFixture self;
if (debug) {
GTSAM_PRINT(self.bayesNet);
@ -95,47 +99,56 @@ TEST(DiscreteBayesTree, ThinTree) {
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
}
auto R = self.bayesTree->roots().front();
// Check whether BN and BT give the same answer on all configurations
auto allPosbValues = DiscreteValues::CartesianProduct(
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
keys[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
for (const auto& x : self.assignments) {
double expected = self.bayesNet.evaluate(x);
double actual = self.bayesTree->evaluate(x);
DOUBLES_EQUAL(expected, actual, 1e-9);
}
}
// Calculate all some marginals for DiscreteValues==all1
Vector marginals = Vector::Zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
/* ************************************************************************* */
// Check calculation of separator marginals
TEST(DiscreteBayesTree, SeparatorMarginals) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
double marginal_14 = 0, joint_8_12 = 0;
for (auto& x : self.assignments) {
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
joint_12_14 += px;
if (x[9]) joint_9_12_14 += px;
if (x[8]) joint_8_12_14 += px;
}
if (x[8] && x[12]) joint_8_12 += px;
if (x[2]) {
if (x[8]) joint82 += px;
if (x[1]) joint12 += px;
}
if (x[4]) {
if (x[2]) joint24 += px;
if (x[5]) joint45 += px;
if (x[6]) joint46 += px;
if (x[11]) joint_4_11 += px;
}
if (x[14]) marginal_14 += px;
}
DiscreteValues all1 = self.assignments.back();
// check separator marginal P(S0)
auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
}
/* ************************************************************************* */
// Check shortcuts in the tree
TEST(DiscreteBayesTree, Shortcuts) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0,
joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (auto& x : self.assignments) {
double px = self.bayesTree->evaluate(x);
if (x[11] && x[13]) {
joint_11_13 += px;
if (x[8] && x[12]) joint_8_11_12_13 += px;
@ -148,32 +161,12 @@ TEST(DiscreteBayesTree, ThinTree) {
}
}
}
DiscreteValues all1 = allPosbValues.back();
DiscreteValues all1 = self.assignments.back();
// check separator marginal P(S0)
auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
auto R = self.bayesTree->roots().front();
// check shortcut P(S9||R) to root
clique = (*self.bayesTree)[9];
auto clique = (*self.bayesTree)[9];
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
LONGS_EQUAL(1, shortcut.size());
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
@ -202,15 +195,67 @@ TEST(DiscreteBayesTree, ThinTree) {
shortcut.print("shortcut:");
}
}
}
/* ************************************************************************* */
// Check all marginals
TEST(DiscreteBayesTree, MarginalFactors) {
TestFixture self;
Vector marginals = Vector::Zero(15);
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
}
// Check all marginals
DiscreteFactor::shared_ptr marginalFactor;
DiscreteValues all1 = self.assignments.back();
for (size_t i = 0; i < 15; i++) {
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1);
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
}
}
/* ************************************************************************* */
// Check a number of joint marginals.
TEST(DiscreteBayesTree, Joints) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
Vector marginals = Vector::Zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
joint_12_14 += px;
if (x[9]) joint_9_12_14 += px;
if (x[8]) joint_8_12_14 += px;
}
if (x[2]) {
if (x[8]) joint82 += px;
if (x[1]) joint12 += px;
}
if (x[4]) {
if (x[2]) joint24 += px;
if (x[5]) joint45 += px;
if (x[6]) joint46 += px;
if (x[11]) joint_4_11 += px;
}
}
// regression tests:
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
DiscreteValues all1 = self.assignments.back();
DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8, 2)
@ -240,8 +285,8 @@ TEST(DiscreteBayesTree, ThinTree) {
/* ************************************************************************* */
TEST(DiscreteBayesTree, Dot) {
const TestFixture self;
string actual = self.bayesTree->dot();
TestFixture self;
std::string actual = self.bayesTree->dot();
EXPECT(actual ==
"digraph G{\n"
"0[label=\"13, 11, 6, 7\"];\n"
@ -268,6 +313,62 @@ TEST(DiscreteBayesTree, Dot) {
"}");
}
/* ************************************************************************* */
// Check that we can have a multi-frontal lookup table
TEST(DiscreteBayesTree, Lookup) {
using gtsam::symbol_shorthand::A;
using gtsam::symbol_shorthand::X;
// Make a small planning-like graph: 3 states, 2 actions
DiscreteFactorGraph graph;
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
// Constraint on start and goal
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
// Should I stay or should I go?
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
const double r = 10;
std::vector<double> table{
r, 0, 0, 0, r, 0, // x1 = 0
0, r, 0, 0, 0, r, // x1 = 1
0, 0, r, 0, 0, r // x1 = 2
};
graph.add(DiscreteKeys{x1, a1, x2}, table);
graph.add(DiscreteKeys{x2, a2, x3}, table);
// eliminate for MPE (maximum probable explanation).
Ordering ordering{A(2), X(3), X(1), A(1), X(2)};
auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE);
// Check that the lookup table is correct
EXPECT_LONGS_EQUAL(2, lookup->size());
auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional();
EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size());
// check that sum is 100
DiscreteValues empty;
EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9);
// And that only non-zero reward is for x1 a1 x2 == 0 1 1
EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9);
auto lookup_a2_x3 = (*lookup)[X(3)]->conditional();
// check that the sum depends on x2 and is non-zero only for x2 \in {1,2}
auto sum_x2 = lookup_a2_x3->sum(2);
EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9);
EXPECT_DOUBLES_EQUAL(10, (*sum_x2)({{X(2),1}}), 1e-9);
EXPECT_DOUBLES_EQUAL(20, (*sum_x2)({{X(2),2}}), 1e-9);
EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size());
// And that the non-zero rewards are for
// x2 a2 x3 == 1 1 2
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9);
// x2 a2 x3 == 2 0 2
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9);
// x2 a2 x3 == 2 1 2
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -106,7 +106,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
result = addGaussian(result, gf);
} else if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
result = gmf->add(result);
} else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
result = gm->add(result);
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
@ -283,17 +285,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// taking care to correct for conditional constant.
// Correct for the normalization constant used up by the conditional
auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
auto correct = [&](const Result &pair) {
const auto &factor = pair.second;
if (!factor) return factor; // TODO(dellaert): not loving this.
if (!factor) return;
auto hf = boost::dynamic_pointer_cast<HessianFactor>(factor);
if (!hf) throw std::runtime_error("Expected HessianFactor!");
hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
return hf;
};
eliminationResults.visit(correct);
GaussianMixtureFactor::Factors correctedFactors(eliminationResults,
correct);
const auto mixtureFactor = boost::make_shared<GaussianMixtureFactor>(
continuousSeparator, discreteSeparator, newFactors);

View File

@ -17,6 +17,7 @@
*/
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
#include <gtsam/hybrid/MixtureFactor.h>
@ -69,6 +70,12 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
// If discrete-only: doesn't need linearization.
linearFG->push_back(f);
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
linearFG->push_back(gmf);
} else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
linearFG->push_back(gm);
} else if (dynamic_pointer_cast<GaussianFactor>(f)) {
linearFG->push_back(f);
} else {
auto& fr = *f;
throw std::invalid_argument(

View File

@ -23,6 +23,37 @@
namespace gtsam {
/* ************************************************************************* */
Ordering HybridSmoother::getOrdering(
const HybridGaussianFactorGraph &newFactors) {
HybridGaussianFactorGraph factors(hybridBayesNet());
factors += newFactors;
// Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeySet();
// Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast;
const KeySet newFactorKeys = newFactors.keys();
// Insert continuous keys first.
for (auto &k : newFactorKeys) {
if (!allDiscrete.exists(k)) {
newKeysDiscreteLast.push_back(k);
}
}
// Insert discrete keys at the end
std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast));
const VariableIndex index(newFactors);
// Get an ordering where the new keys are eliminated last
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
true);
return ordering;
}
/* ************************************************************************* */
void HybridSmoother::update(HybridGaussianFactorGraph graph,
const Ordering &ordering,
@ -92,7 +123,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
}
graph.push_back(newConditionals);
// newConditionals.print("\n\n\nNew Conditionals to add back");
}
return {graph, hybridBayesNet};
}

View File

@ -50,6 +50,8 @@ class HybridSmoother {
void update(HybridGaussianFactorGraph graph, const Ordering& ordering,
boost::optional<size_t> maxNrLeaves = boost::none);
Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
/**
* @brief Add conditionals from previous timestep as part of liquefication.
*

View File

@ -35,14 +35,11 @@ class HybridValues {
};
#include <gtsam/hybrid/HybridFactor.h>
virtual class HybridFactor {
virtual class HybridFactor : gtsam::Factor {
void print(string s = "HybridFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
gtsam::KeyVector keys() const;
// Standard interface:
double error(const gtsam::HybridValues &values) const;

View File

@ -93,6 +93,7 @@ TEST(GaussianMixtureFactor, Sum) {
EXPECT(actual.at(1) == f22);
}
/* ************************************************************************* */
TEST(GaussianMixtureFactor, Printing) {
DiscreteKey m1(1, 2);
auto A1 = Matrix::Zero(2, 1);
@ -136,6 +137,7 @@ TEST(GaussianMixtureFactor, Printing) {
EXPECT(assert_print_equal(expected, mixtureFactor));
}
/* ************************************************************************* */
TEST(GaussianMixtureFactor, GaussianMixture) {
KeyVector keys;
keys.push_back(X(0));

View File

@ -612,7 +612,6 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
// Check that assembleGraphTree assembles Gaussian factor graphs for each
// assignment.
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
using symbol_shorthand::Z;
const int num_measurements = 1;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
@ -694,7 +693,6 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
/* ****************************************************************************/
// Check that eliminating tiny net with 1 measurement yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
using symbol_shorthand::Z;
const int num_measurements = 1;
const VectorValues measurements{{Z(0), Vector1(5.0)}};
auto bn = tiny::createHybridBayesNet(num_measurements);
@ -726,11 +724,67 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
EXPECT(ratioTest(bn, measurements, *posterior));
}
/* ****************************************************************************/
// Check that eliminating tiny net with 1 measurement with mode order swapped
// yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
const VectorValues measurements{{Z(0), Vector1(5.0)}};
// Create mode key: 1 is low-noise, 0 is high-noise.
const DiscreteKey mode{M(0), 2};
HybridBayesNet bn;
// Create Gaussian mixture z_0 = x0 + noise for each measurement.
bn.emplace_back(new GaussianMixture(
{Z(0)}, {X(0)}, {mode},
{GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3),
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1,
0.5)}));
// Create prior on X(0).
bn.push_back(
GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5));
// Add prior on mode.
bn.emplace_back(new DiscreteConditional(mode, "1/1"));
// bn.print();
auto fg = bn.toFactorGraph(measurements);
EXPECT_LONGS_EQUAL(3, fg.size());
// fg.print();
EXPECT(ratioTest(bn, measurements, fg));
// Create expected Bayes Net:
HybridBayesNet expectedBayesNet;
// Create Gaussian mixture on X(0).
// regression, but mean checked to be 5.0 in both cases:
const auto conditional0 = boost::make_shared<GaussianConditional>(
X(0), Vector1(10.1379), I_1x1 * 2.02759),
conditional1 = boost::make_shared<GaussianConditional>(
X(0), Vector1(14.1421), I_1x1 * 2.82843);
expectedBayesNet.emplace_back(
new GaussianMixture({X(0)}, {}, {mode}, {conditional0, conditional1}));
// Add prior on mode.
expectedBayesNet.emplace_back(new DiscreteConditional(mode, "1/1"));
// Test elimination
const auto posterior = fg.eliminateSequential();
// EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
EXPECT(ratioTest(bn, measurements, *posterior));
// posterior->print();
// posterior->optimize().print();
}
/* ****************************************************************************/
// Check that eliminating tiny net with 2 measurements yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny2) {
// Create factor graph with 2 measurements such that posterior mean = 5.0.
using symbol_shorthand::Z;
const int num_measurements = 2;
const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
auto bn = tiny::createHybridBayesNet(num_measurements);
@ -764,7 +818,6 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
// Test eliminating tiny net with 1 mode per measurement.
TEST(HybridGaussianFactorGraph, EliminateTiny22) {
// Create factor graph with 2 measurements such that posterior mean = 5.0.
using symbol_shorthand::Z;
const int num_measurements = 2;
const bool manyModes = true;
@ -835,12 +888,12 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
// D D
// | |
// m1 m2
// | |
// | |
// C-x0-HC-x1-HC-x2
// | | |
// HF HF HF
// | | |
// n0 n1 n2
// n0 n1 n2
// | | |
// D D D
EXPECT_LONGS_EQUAL(11, fg.size());
@ -853,7 +906,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
EXPECT(ratioTest(bn, measurements, fg1));
// Create ordering that eliminates in time order, then discrete modes:
Ordering ordering {X(2), X(1), X(0), N(0), N(1), N(2), M(1), M(2)};
Ordering ordering{X(2), X(1), X(0), N(0), N(1), N(2), M(1), M(2)};
// Do elimination:
const HybridBayesNet::shared_ptr posterior = fg.eliminateSequential(ordering);

View File

@ -140,9 +140,15 @@ namespace gtsam {
/** Access the conditional */
const sharedConditional& conditional() const { return conditional_; }
/** is this the root of a Bayes tree ? */
/// Return true if this clique is the root of a Bayes tree.
inline bool isRoot() const { return parent_.expired(); }
/// Return the number of children.
size_t nrChildren() const { return children.size(); }
/// Return the child at index i.
const derived_ptr operator[](size_t i) const { return children[i]; }
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
size_t treeSize() const;

View File

@ -49,7 +49,7 @@ class ClusterTree {
virtual ~Cluster() {}
const Cluster& operator[](size_t i) const {
return *(children[i]);
return *(children.at(i));
}
/// Construct from factors associated with a single key
@ -161,7 +161,7 @@ class ClusterTree {
}
const Cluster& operator[](size_t i) const {
return *(roots_[i]);
return *(roots_.at(i));
}
/// @}

View File

@ -52,12 +52,12 @@ namespace gtsam {
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
* expose functions for computing marginals, conditional marginals, doing multifrontal and
* sequential elimination, etc. */
template<class FACTORGRAPH>
template<class FACTOR_GRAPH>
class EliminateableFactorGraph
{
private:
typedef EliminateableFactorGraph<FACTORGRAPH> This; ///< Typedef to this class.
typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type
typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
// Base factor type stored in this graph (private because derived classes will get this from
// their FactorGraph base class)
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
@ -139,7 +139,7 @@ namespace gtsam {
OptionalVariableIndex variableIndex = boost::none) const;
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on
* provided, the ordering will be computed using either COLAMD or METIS, depending on
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
*
* <b> Example - Full Cholesky elimination in COLAMD order: </b>
@ -160,7 +160,7 @@ namespace gtsam {
OptionalVariableIndex variableIndex = boost::none) const;
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on
* provided, the ordering will be computed using either COLAMD or METIS, depending on
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
*
* <b> Example - Full QR elimination in specified order:

View File

@ -104,6 +104,7 @@ class Ordering {
// Standard Constructors and Named Constructors
Ordering();
Ordering(const gtsam::Ordering& other);
Ordering(const std::vector<size_t>& keys);
template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
@ -147,7 +148,7 @@ class Ordering {
// Standard interface
size_t size() const;
size_t at(size_t key) const;
size_t at(size_t i) const;
void push_back(size_t key);
// enabling serialization functionality
@ -197,4 +198,15 @@ class VariableIndex {
size_t nEntries() const;
};
#include <gtsam/inference/Factor.h>
virtual class Factor {
void print(string s = "Factor\n", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printKeys(string s = "") const;
bool equals(const gtsam::Factor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
gtsam::KeyVector keys() const;
};
} // namespace gtsam

View File

@ -261,8 +261,7 @@ class VectorValues {
};
#include <gtsam/linear/GaussianFactor.h>
virtual class GaussianFactor {
gtsam::KeyVector keys() const;
virtual class GaussianFactor : gtsam::Factor {
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
@ -273,8 +272,6 @@ virtual class GaussianFactor {
Matrix information() const;
Matrix augmentedJacobian() const;
pair<Matrix, Vector> jacobian() const;
size_t size() const;
bool empty() const;
};
#include <gtsam/linear/JacobianFactor.h>
@ -301,10 +298,7 @@ virtual class JacobianFactor : gtsam::GaussianFactor {
//Testable
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printKeys(string s) const;
gtsam::KeyVector& keys() const;
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
size_t size() const;
Vector unweighted_error(const gtsam::VectorValues& c) const;
Vector error_vector(const gtsam::VectorValues& c) const;
double error(const gtsam::VectorValues& c) const;
@ -346,10 +340,8 @@ virtual class HessianFactor : gtsam::GaussianFactor {
HessianFactor(const gtsam::GaussianFactorGraph& factors);
//Testable
size_t size() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printKeys(string s) const;
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
double error(const gtsam::VectorValues& c) const;

View File

@ -110,13 +110,10 @@ class NonlinearFactorGraph {
};
#include <gtsam/nonlinear/NonlinearFactor.h>
virtual class NonlinearFactor {
virtual class NonlinearFactor : gtsam::Factor {
// Factor base class
size_t size() const;
gtsam::KeyVector keys() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printKeys(string s) const;
// NonlinearFactor
bool equals(const gtsam::NonlinearFactor& other, double tol) const;
double error(const gtsam::Values& c) const;

View File

@ -20,6 +20,7 @@
#include <algorithm>
#include <iostream>
#include <iomanip>
namespace gtsam {
@ -38,7 +39,8 @@ static DSFMapIndexPair generateDSF(const MatchIndicesMap& matches) {
// Image pair is (i1,i2).
size_t i1 = pair_indices.first;
size_t i2 = pair_indices.second;
for (size_t k = 0; k < corr_indices.rows(); k++) {
size_t m = static_cast<size_t>(corr_indices.rows());
for (size_t k = 0; k < m; k++) {
// Measurement indices are found in a single matrix row, as (k1,k2).
size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1);
// Unique key for DSF is (i,k), representing keypoint index in an image.
@ -128,7 +130,7 @@ std::vector<SfmTrack2d> tracksFromPairwiseMatches(
}
// TODO(johnwlambert): return the Transitivity failure percentage here.
return tracks2d;
return validTracks;
}
} // namespace gtsfm

View File

@ -20,9 +20,10 @@
#include <gtsam/base/DSFMap.h>
#include <gtsam/sfm/SfmTrack.h>
#include <boost/optional.hpp>
#include <Eigen/Core>
#include <map>
#include <optional>
#include <vector>
namespace gtsam {

View File

@ -65,4 +65,6 @@ namespace gtsam {
SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree);
};
/// typedef for wrapper:
using SymbolicCluster = SymbolicJunctionTree::Cluster;
}

View File

@ -4,7 +4,7 @@
namespace gtsam {
#include <gtsam/symbolic/SymbolicFactor.h>
virtual class SymbolicFactor {
virtual class SymbolicFactor : gtsam::Factor {
// Standard Constructors and Named Constructors
SymbolicFactor(const gtsam::SymbolicFactor& f);
SymbolicFactor();
@ -18,12 +18,10 @@ virtual class SymbolicFactor {
static gtsam::SymbolicFactor FromKeys(const gtsam::KeyVector& js);
// From Factor
size_t size() const;
void print(string s = "SymbolicFactor",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::SymbolicFactor& other, double tol) const;
gtsam::KeyVector keys();
};
#include <gtsam/symbolic/SymbolicFactorGraph.h>
@ -139,7 +137,60 @@ class SymbolicBayesNet {
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};
#include <gtsam/symbolic/SymbolicEliminationTree.h>
class SymbolicEliminationTree {
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::SymbolicEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/symbolic/SymbolicJunctionTree.h>
class SymbolicCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::SymbolicFactorGraph factors;
const gtsam::SymbolicCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class SymbolicJunctionTree {
SymbolicJunctionTree(const gtsam::SymbolicEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::SymbolicCluster& operator[](size_t i) const;
};
#include <gtsam/symbolic/SymbolicBayesTree.h>
class SymbolicBayesTreeClique {
SymbolicBayesTreeClique();
SymbolicBayesTreeClique(const gtsam::SymbolicConditional* conditional);
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter);
const gtsam::SymbolicConditional* conditional() const;
bool isRoot() const;
gtsam::SymbolicBayesTreeClique* parent() const;
size_t treeSize() const;
size_t numCachedSeparatorMarginals() const;
void deleteCachedShortcuts();
};
class SymbolicBayesTree {
// Constructors
SymbolicBayesTree();
@ -151,9 +202,14 @@ class SymbolicBayesTree {
bool equals(const gtsam::SymbolicBayesTree& other, double tol) const;
// Standard Interface
// size_t findParentClique(const gtsam::IndexVector& parents) const;
size_t size();
void saveGraph(string s) const;
bool empty() const;
size_t size() const;
const gtsam::SymbolicBayesTreeClique* operator[](size_t j) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void clear();
void deleteCachedShortcuts();
size_t numCachedSeparatorMarginals() const;
@ -161,28 +217,9 @@ class SymbolicBayesTree {
gtsam::SymbolicConditional* marginalFactor(size_t key) const;
gtsam::SymbolicFactorGraph* joint(size_t key1, size_t key2) const;
gtsam::SymbolicBayesNet* jointBayesNet(size_t key1, size_t key2) const;
};
class SymbolicBayesTreeClique {
SymbolicBayesTreeClique();
// SymbolicBayesTreeClique(gtsam::sharedConditional* conditional);
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
size_t numCachedSeparatorMarginals() const;
// gtsam::sharedConditional* conditional() const;
bool isRoot() const;
size_t treeSize() const;
gtsam::SymbolicBayesTreeClique* parent() const;
// // TODO: need wrapped versions graphs, BayesNet
// BayesNet<ConditionalType> shortcut(derived_ptr root, Eliminate function)
// const; FactorGraph<FactorType> marginal(derived_ptr root, Eliminate
// function) const; FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr
// root, Eliminate function) const;
//
void deleteCachedShortcuts();
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
} // namespace gtsam

View File

@ -25,7 +25,13 @@ class TestDecisionTreeFactor(GtsamTestCase):
self.B = (5, 2)
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
def test_from_floats(self):
"""Test whether we can construct a factor from floats."""
actual = DecisionTreeFactor([self.A, self.B], [1., 2., 3., 4., 5., 6.])
self.gtsamAssertEquals(actual, self.factor)
def test_enumerate(self):
"""Test whether we can enumerate the factor."""
actual = self.factor.enumerate()
_, values = zip(*actual)
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

View File

@ -13,10 +13,15 @@ Author: Frank Dellaert
import unittest
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph, Ordering)
import numpy as np
from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph,
DiscreteValues, Ordering)
class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets."""
@ -27,7 +32,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
# Define DiscreteKey pairs.
keys = [(j, 2) for j in range(15)]
# Create thin-tree Bayesnet.
# Create thin-tree Bayes net.
bayesNet = DiscreteBayesNet()
bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1")
@ -65,15 +70,105 @@ class TestDiscreteBayesNet(GtsamTestCase):
# bayesTree[key].printSignature()
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())
# The root is P( 8 12 14), we can retrieve it by key:
root = bayesTree[8]
self.assertIsInstance(root, DiscreteBayesTreeClique)
self.assertTrue(root.isRoot())
self.assertIsInstance(root.conditional(), DiscreteConditional)
# Test all methods in DiscreteBayesTree
self.gtsamAssertEquals(bayesTree, bayesTree)
# Check value at 0
zero_values = DiscreteValues()
for j in range(15):
zero_values[j] = 0
value_at_zeros = bayesTree.evaluate(zero_values)
self.assertAlmostEqual(value_at_zeros, 0.0)
# Check value at max
values_star = factorGraph.optimize()
max_value = bayesTree.evaluate(values_star)
self.assertAlmostEqual(max_value, 0.002548)
# Check operator sugar
max_value = bayesTree(values_star)
self.assertAlmostEqual(max_value, 0.002548)
self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())
@unittest.skip("TODO: segfaults on gcc 7 and gcc 9")
def test_discrete_bayes_tree_lookup(self):
"""Check that we can have a multi-frontal lookup table."""
# Make a small planning-like graph: 3 states, 2 actions
graph = DiscreteFactorGraph()
x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
a1, a2 = (A(1), 2), (A(2), 2)
# Constraint on start and goal
graph.add([x1], np.array([1, 0, 0]))
graph.add([x3], np.array([0, 0, 1]))
# Should I stay or should I go?
# "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
r = 10
table = np.array([
r, 0, 0, 0, r, 0, # x1 = 0
0, r, 0, 0, 0, r, # x1 = 1
0, 0, r, 0, 0, r # x1 = 2
])
graph.add([x1, a1, x2], table)
graph.add([x2, a2, x3], table)
# print(graph) will give:
# size: 4
# factor 0: f[ (x1,3), ] ...
# factor 1: f[ (x3,3), ] ...
# factor 2: f[ (x1,3), (a1,2), (x2,3), ] ...
# factor 3: f[ (x2,3), (a2,2), (x3,3), ] ...
# Eliminate for MPE (maximum probable explanation).
ordering = Ordering(keys=[A(2), X(3), X(1), A(1), X(2)])
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
# print(lookup) will give:
# DiscreteBayesTree
# : cliques: 2, variables: 5
# - g( x1 a1 x2 ): ...
# | - g( a2 x3 ; x2 ): ...
# Check that the lookup table is correct
assert lookup.size() == 2
lookup_x1_a1_x2 = lookup[X(1)].conditional()
assert lookup_x1_a1_x2.nrFrontals() == 3
# Check that sum is 100
empty = gtsam.DiscreteValues()
self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 100)
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
values = DiscreteValues()
values[X(1)] = 0
values[A(1)] = 1
values[X(2)] = 1
self.assertAlmostEqual(lookup_x1_a1_x2(values), 100)
lookup_a2_x3 = lookup[X(3)].conditional()
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
sum_x2 = lookup_a2_x3.sum(2)
values = DiscreteValues()
values[X(2)] = 0
self.assertAlmostEqual(sum_x2(values), 0)
values[X(2)] = 1
self.assertAlmostEqual(sum_x2(values), 10)
values[X(2)] = 2
self.assertAlmostEqual(sum_x2(values), 20)
assert lookup_a2_x3.nrFrontals() == 2
# And that the non-zero rewards are for x2 a2 x3 == 1 1 2
values = DiscreteValues()
values[X(2)] = 1
values[A(2)] = 1
values[X(3)] = 2
self.assertAlmostEqual(lookup_a2_x3(values), 10)
if __name__ == "__main__":
unittest.main()

View File

@ -4,18 +4,44 @@ Authors: John Lambert
"""
import unittest
from typing import Dict, Tuple
import gtsam
import numpy as np
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
SfmMeasurementVector, SfmTrack2d)
from gtsam.gtsfm import Keypoints
from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
SfmMeasurementVector, SfmTrack2d)
class TestDsfTrackGenerator(GtsamTestCase):
"""Tests for DsfTrackGenerator."""
def test_generate_tracks_from_pairwise_matches_nontransitive(
self,
) -> None:
"""Tests DSF for non-transitive matches.
Test will result in no tracks since nontransitive tracks are naively
discarded by DSF.
"""
keypoints = get_dummy_keypoints_list()
nontransitive_matches = get_nontransitive_matches()
# For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding keypoint indices (k1,k2).
matches = MatchIndicesMap()
for (i1, i2), correspondences in nontransitive_matches.items():
matches[IndexPair(i1, i2)] = correspondences
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
matches,
keypoints,
verbose=True,
)
self.assertEqual(len(tracks), 0, "Tracks not filtered correctly")
def test_track_generation(self) -> None:
"""Ensures that DSF generates three tracks from measurements
in 3 images (H=200,W=400)."""
@ -23,20 +49,20 @@ class TestDsfTrackGenerator(GtsamTestCase):
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))
keypoints_list = KeypointsVector()
keypoints_list.append(kps_i0)
keypoints_list.append(kps_i1)
keypoints_list.append(kps_i2)
keypoints = KeypointsVector()
keypoints.append(kps_i0)
keypoints.append(kps_i1)
keypoints.append(kps_i2)
# For each image pair (i1,i2), we provide a (K,2) matrix
# of corresponding image indices (k1,k2).
matches_dict = MatchIndicesMap()
matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
matches = MatchIndicesMap()
matches[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
matches[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
matches_dict,
keypoints_list,
matches,
keypoints,
verbose=False,
)
assert len(tracks) == 3
@ -93,5 +119,71 @@ class TestSfmTrack2d(GtsamTestCase):
assert track.numberMeasurements() == 1
def get_dummy_keypoints_list() -> KeypointsVector:
"""Generate a list of dummy keypoints for testing."""
img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]])
img2_kp_coords = np.array(
[
[1, 1.],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
[6, 6],
[7, 7],
[8, 8],
]
)
img3_kp_coords = np.array(
[
[1, 1.],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
[6, 6],
[7, 7],
[8, 8],
[9, 9],
[10, 10],
]
)
img4_kp_coords = np.array(
[
[1, 1.],
[2, 2],
[3, 3],
[4, 4],
[5, 5],
]
)
keypoints = KeypointsVector()
keypoints.append(Keypoints(coordinates=img1_kp_coords))
keypoints.append(Keypoints(coordinates=img2_kp_coords))
keypoints.append(Keypoints(coordinates=img3_kp_coords))
keypoints.append(Keypoints(coordinates=img4_kp_coords))
return keypoints
def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]:
"""Set up correspondences for each (i1,i2) pair that violates transitivity.
(i=0, k=0) (i=0, k=1)
| \\ |
| \\ |
(i=1, k=2)--(i=2,k=3)--(i=3, k=4)
Transitivity is violated due to the match between frames 0 and 3.
"""
nontransitive_matches = {
(0, 1): np.array([[0, 2]]),
(1, 2): np.array([[2, 3]]),
(0, 2): np.array([[0, 3]]),
(0, 3): np.array([[1, 4]]),
(2, 3): np.array([[3, 4]]),
}
return nontransitive_matches
if __name__ == "__main__":
unittest.main()