Merge pull request #1619 from borglab/release/4.2
commit
4f66a491ff
|
@ -28,9 +28,9 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Algebraic Decision Trees fix the range to double
|
* An algebraic decision tree fixes the range of a DecisionTree to double.
|
||||||
* Just has some nice constructors and some syntactic sugar
|
* Just has some nice constructors and some syntactic sugar.
|
||||||
* TODO: consider eliminating this class altogether?
|
* TODO(dellaert): consider eliminating this class altogether?
|
||||||
*
|
*
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
|
@ -80,12 +80,47 @@ namespace gtsam {
|
||||||
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
||||||
: Base(label, y1, y2) {}
|
: Base(label, y1, y2) {}
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/**
|
||||||
|
* @brief Create a new leaf function splitting on a variable
|
||||||
|
*
|
||||||
|
* @param labelC: The label with cardinality 2
|
||||||
|
* @param y1: The value for the first key
|
||||||
|
* @param y2: The value for the second key
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* @code{.cpp}
|
||||||
|
* std::pair<string, size_t> A {"a", 2};
|
||||||
|
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
|
||||||
|
* @endcode
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
||||||
double y2)
|
double y2)
|
||||||
: Base(labelC, y1, y2) {}
|
: Base(labelC, y1, y2) {}
|
||||||
|
|
||||||
/** Create from keys and vector table */
|
/**
|
||||||
|
* @brief Create from keys with cardinalities and a vector table
|
||||||
|
*
|
||||||
|
* @param labelCs: The keys, with cardinalities, given as pairs
|
||||||
|
* @param ys: The vector table
|
||||||
|
*
|
||||||
|
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
|
||||||
|
* respectively, and a vector table of size 12:
|
||||||
|
* @code{.cpp}
|
||||||
|
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||||
|
* const vector<double> cpt{
|
||||||
|
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
|
||||||
|
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
|
||||||
|
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
|
||||||
|
* @endcode
|
||||||
|
* The table is given in the following order:
|
||||||
|
* A=0, B=0, C=0
|
||||||
|
* A=0, B=0, C=1
|
||||||
|
* ...
|
||||||
|
* A=1, B=1, C=1
|
||||||
|
* Hence, the first line in the table is for A==0, and the second for A==1.
|
||||||
|
* In each line, the first two entries are for B==0, the next two for B==1,
|
||||||
|
* and the last two for B==2. Each pair is for a C value of 0 and 1.
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs,
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
const std::vector<double>& ys) {
|
const std::vector<double>& ys) {
|
||||||
|
@ -93,7 +128,14 @@ namespace gtsam {
|
||||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create from keys and string table */
|
/**
|
||||||
|
* @brief Create from keys and string table
|
||||||
|
*
|
||||||
|
* @param labelCs: The keys, with cardinalities, given as pairs
|
||||||
|
* @param table: The string table, given as a string of doubles.
|
||||||
|
*
|
||||||
|
* @note Table needs to be in same order as the vector table in the other constructor.
|
||||||
|
*/
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs,
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
const std::string& table) {
|
const std::string& table) {
|
||||||
|
@ -108,7 +150,13 @@ namespace gtsam {
|
||||||
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new function splitting on a variable */
|
/**
|
||||||
|
* @brief Create a range of decision trees, splitting on a single variable.
|
||||||
|
*
|
||||||
|
* @param begin: Iterator to beginning of a range of decision trees
|
||||||
|
* @param end: Iterator to end of a range of decision trees
|
||||||
|
* @param label: The label to split on
|
||||||
|
*/
|
||||||
template <typename Iterator>
|
template <typename Iterator>
|
||||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
||||||
: Base(nullptr) {
|
: Base(nullptr) {
|
||||||
|
|
|
@ -622,7 +622,7 @@ namespace gtsam {
|
||||||
// B=1
|
// B=1
|
||||||
// A=0: 3
|
// A=0: 3
|
||||||
// A=1: 4
|
// A=1: 4
|
||||||
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce
|
// Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
|
||||||
// exactly the same tree as above: the highest label is always the root.
|
// exactly the same tree as above: the highest label is always the root.
|
||||||
// However, it will be *way* faster if labels are given highest to lowest.
|
// However, it will be *way* faster if labels are given highest to lowest.
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
|
|
|
@ -37,9 +37,23 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Decision Tree
|
* @brief a decision tree is a function from assignments to values.
|
||||||
* L = label for variables
|
* @tparam L label for variables
|
||||||
* Y = function range (any algebra), e.g., bool, int, double
|
* @tparam Y function range (any algebra), e.g., bool, int, double
|
||||||
|
*
|
||||||
|
* After creating a decision tree on some variables, the tree can be evaluated
|
||||||
|
* on an assignment to those variables. Example:
|
||||||
|
*
|
||||||
|
* @code{.cpp}
|
||||||
|
* // Create a decision stump one one variable 'a' with values 10 and 20.
|
||||||
|
* DecisionTree<char, int> tree('a', 10, 20);
|
||||||
|
*
|
||||||
|
* // Evaluate the tree on an assignment to the variable.
|
||||||
|
* int value0 = tree({{'a', 0}}); // value0 = 10
|
||||||
|
* int value1 = tree({{'a', 1}}); // value1 = 20
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* More examples can be found in testDecisionTree.cpp
|
||||||
*
|
*
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
|
@ -132,7 +146,8 @@ namespace gtsam {
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/** Internal recursive function to create from keys, cardinalities,
|
/**
|
||||||
|
* Internal recursive function to create from keys, cardinalities,
|
||||||
* and Y values
|
* and Y values
|
||||||
*/
|
*/
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
|
@ -163,7 +178,13 @@ namespace gtsam {
|
||||||
/** Create a constant */
|
/** Create a constant */
|
||||||
explicit DecisionTree(const Y& y);
|
explicit DecisionTree(const Y& y);
|
||||||
|
|
||||||
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
|
/**
|
||||||
|
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
|
||||||
|
*
|
||||||
|
* @param label The variable to split on.
|
||||||
|
* @param y1 The value for the first assignment.
|
||||||
|
* @param y2 The value for the second assignment.
|
||||||
|
*/
|
||||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||||
|
|
||||||
/** Allow Label+Cardinality for convenience */
|
/** Allow Label+Cardinality for convenience */
|
||||||
|
|
|
@ -63,11 +63,46 @@ namespace gtsam {
|
||||||
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
|
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
||||||
|
|
||||||
/** Constructor from doubles */
|
/**
|
||||||
|
* @brief Constructor from doubles
|
||||||
|
*
|
||||||
|
* @param keys The discrete keys.
|
||||||
|
* @param table The table of values.
|
||||||
|
*
|
||||||
|
* @throw std::invalid_argument if the size of `table` does not match the
|
||||||
|
* number of assignments.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* @code{.cpp}
|
||||||
|
* DiscreteKey X(0,2), Y(1,3);
|
||||||
|
* const std::vector<double> table {2, 5, 3, 6, 4, 7};
|
||||||
|
* DecisionTreeFactor f1({X, Y}, table);
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* The values in the table should be laid out so that the first key varies
|
||||||
|
* the slowest, and the last key the fastest.
|
||||||
|
*/
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys,
|
DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
const std::vector<double>& table);
|
const std::vector<double>& table);
|
||||||
|
|
||||||
/** Constructor from string */
|
/**
|
||||||
|
* @brief Constructor from string
|
||||||
|
*
|
||||||
|
* @param keys The discrete keys.
|
||||||
|
* @param table The table of values.
|
||||||
|
*
|
||||||
|
* @throw std::invalid_argument if the size of `table` does not match the
|
||||||
|
* number of assignments.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* @code{.cpp}
|
||||||
|
* DiscreteKey X(0,2), Y(1,3);
|
||||||
|
* DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7");
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* The values in the table should be laid out so that the first key varies
|
||||||
|
* the slowest, and the last key the fastest.
|
||||||
|
*/
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
||||||
|
|
||||||
/// Single-key specialization
|
/// Single-key specialization
|
||||||
|
|
|
@ -59,6 +59,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
|
||||||
|
|
||||||
//** evaluate conditional probability of subtree for given DiscreteValues */
|
//** evaluate conditional probability of subtree for given DiscreteValues */
|
||||||
double evaluate(const DiscreteValues& values) const;
|
double evaluate(const DiscreteValues& values) const;
|
||||||
|
|
||||||
|
//** (Preferred) sugar for the above for given DiscreteValues */
|
||||||
|
double operator()(const DiscreteValues& values) const {
|
||||||
|
return evaluate(values);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -43,15 +43,29 @@ class DiscreteJunctionTree;
|
||||||
/**
|
/**
|
||||||
* @brief Main elimination function for DiscreteFactorGraph.
|
* @brief Main elimination function for DiscreteFactorGraph.
|
||||||
*
|
*
|
||||||
* @param factors
|
* @param factors The factor graph to eliminate.
|
||||||
* @param keys
|
* @param frontalKeys An ordering for which variables to eliminate.
|
||||||
* @return GTSAM_EXPORT
|
* @return A pair of the resulting conditional and the separator factor.
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
|
GTSAM_EXPORT
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
||||||
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Alternate elimination function for that creates non-normalized lookup tables.
|
||||||
|
*
|
||||||
|
* @param factors The factor graph to eliminate.
|
||||||
|
* @param frontalKeys An ordering for which variables to eliminate.
|
||||||
|
* @return A pair of the resulting lookup table and the separator factor.
|
||||||
|
* @ingroup discrete
|
||||||
|
*/
|
||||||
|
GTSAM_EXPORT
|
||||||
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
|
||||||
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
template<> struct EliminationTraits<DiscreteFactorGraph>
|
template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
{
|
{
|
||||||
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
|
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
|
||||||
|
@ -61,12 +75,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
|
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
|
||||||
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
|
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
|
||||||
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
|
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
|
||||||
|
|
||||||
/// The default dense elimination function
|
/// The default dense elimination function
|
||||||
static std::pair<boost::shared_ptr<ConditionalType>,
|
static std::pair<boost::shared_ptr<ConditionalType>,
|
||||||
boost::shared_ptr<FactorType> >
|
boost::shared_ptr<FactorType> >
|
||||||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||||
return EliminateDiscrete(factors, keys);
|
return EliminateDiscrete(factors, keys);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The default ordering generation function
|
/// The default ordering generation function
|
||||||
static Ordering DefaultOrderingFunc(
|
static Ordering DefaultOrderingFunc(
|
||||||
const FactorGraphType& graph,
|
const FactorGraphType& graph,
|
||||||
|
@ -75,7 +91,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
/**
|
/**
|
||||||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||||
* Factor == DiscreteFactor
|
* Factor == DiscreteFactor
|
||||||
|
@ -109,8 +124,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
|
|
||||||
/** Implicit copy/downcast constructor to override explicit template container
|
/** Implicit copy/downcast constructor to override explicit template container
|
||||||
* constructor */
|
* constructor */
|
||||||
template <class DERIVEDFACTOR>
|
template <class DERIVED_FACTOR>
|
||||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
/// Destructor
|
||||||
virtual ~DiscreteFactorGraph() {}
|
virtual ~DiscreteFactorGraph() {}
|
||||||
|
@ -231,10 +246,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
/// @}
|
/// @}
|
||||||
}; // \ DiscreteFactorGraph
|
}; // \ DiscreteFactorGraph
|
||||||
|
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
|
||||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
|
||||||
const Ordering& frontalKeys);
|
|
||||||
|
|
||||||
/// traits
|
/// traits
|
||||||
template <>
|
template <>
|
||||||
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};
|
||||||
|
|
|
@ -66,4 +66,6 @@ namespace gtsam {
|
||||||
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
|
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// typedef for wrapper:
|
||||||
|
using DiscreteCluster = DiscreteJunctionTree::Cluster;
|
||||||
}
|
}
|
||||||
|
|
|
@ -120,6 +120,11 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Free version of CartesianProduct.
|
||||||
|
inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
|
||||||
|
return DiscreteValues::CartesianProduct(keys);
|
||||||
|
}
|
||||||
|
|
||||||
/// Free version of markdown.
|
/// Free version of markdown.
|
||||||
std::string markdown(const DiscreteValues& values,
|
std::string markdown(const DiscreteValues& values,
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
|
|
|
@ -17,6 +17,8 @@ class DiscreteKeys {
|
||||||
};
|
};
|
||||||
|
|
||||||
// DiscreteValues is added in specializations/discrete.h as a std::map
|
// DiscreteValues is added in specializations/discrete.h as a std::map
|
||||||
|
std::vector<gtsam::DiscreteValues> cartesianProduct(
|
||||||
|
const gtsam::DiscreteKeys& keys);
|
||||||
string markdown(
|
string markdown(
|
||||||
const gtsam::DiscreteValues& values,
|
const gtsam::DiscreteValues& values,
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
|
||||||
|
@ -31,13 +33,11 @@ string html(const gtsam::DiscreteValues& values,
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names);
|
std::map<gtsam::Key, std::vector<std::string>> names);
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
class DiscreteFactor {
|
virtual class DiscreteFactor : gtsam::Factor {
|
||||||
void print(string s = "DiscreteFactor\n",
|
void print(string s = "DiscreteFactor\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||||
bool empty() const;
|
|
||||||
size_t size() const;
|
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -49,7 +49,12 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
const std::vector<double>& spec);
|
const std::vector<double>& spec);
|
||||||
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
|
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
|
||||||
|
|
||||||
|
DecisionTreeFactor(const gtsam::DiscreteKeys& keys,
|
||||||
|
const std::vector<double>& table);
|
||||||
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
|
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
|
||||||
|
|
||||||
|
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys,
|
||||||
|
const std::vector<double>& table);
|
||||||
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
|
||||||
|
|
||||||
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
|
DecisionTreeFactor(const gtsam::DiscreteConditional& c);
|
||||||
|
@ -59,6 +64,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
size_t cardinality(gtsam::Key j) const;
|
||||||
|
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
||||||
size_t cardinality(gtsam::Key j) const;
|
size_t cardinality(gtsam::Key j) const;
|
||||||
|
@ -66,6 +73,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
|
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
|
||||||
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
|
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
|
||||||
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
|
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
|
||||||
|
gtsam::DecisionTreeFactor* max(const gtsam::Ordering& keys) const;
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
@ -203,10 +211,16 @@ class DiscreteBayesTreeClique {
|
||||||
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
|
||||||
const gtsam::DiscreteConditional* conditional() const;
|
const gtsam::DiscreteConditional* conditional() const;
|
||||||
bool isRoot() const;
|
bool isRoot() const;
|
||||||
|
size_t nrChildren() const;
|
||||||
|
const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const;
|
||||||
|
void print(string s = "DiscreteBayesTreeClique",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void printSignature(
|
void printSignature(
|
||||||
const string& s = "Clique: ",
|
const string& s = "Clique: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
double evaluate(const gtsam::DiscreteValues& values) const;
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class DiscreteBayesTree {
|
class DiscreteBayesTree {
|
||||||
|
@ -220,6 +234,9 @@ class DiscreteBayesTree {
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
const DiscreteBayesTreeClique* operator[](size_t j) const;
|
||||||
|
|
||||||
|
double evaluate(const gtsam::DiscreteValues& values) const;
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
|
||||||
string dot(const gtsam::KeyFormatter& keyFormatter =
|
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void saveGraph(string s,
|
void saveGraph(string s,
|
||||||
|
@ -242,9 +259,9 @@ class DiscreteBayesTree {
|
||||||
class DiscreteLookupTable : gtsam::DiscreteConditional{
|
class DiscreteLookupTable : gtsam::DiscreteConditional{
|
||||||
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
|
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
|
||||||
const gtsam::DecisionTreeFactor::ADT& potentials);
|
const gtsam::DecisionTreeFactor::ADT& potentials);
|
||||||
void print(
|
void print(string s = "Discrete Lookup Table: ",
|
||||||
const std::string& s = "Discrete Lookup Table: ",
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
|
size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -263,6 +280,14 @@ class DiscreteLookupDAG {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
|
||||||
|
EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors,
|
||||||
|
const gtsam::Ordering& frontalKeys);
|
||||||
|
|
||||||
|
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
|
||||||
|
EliminateForMPE(const gtsam::DiscreteFactorGraph& factors,
|
||||||
|
const gtsam::Ordering& frontalKeys);
|
||||||
|
|
||||||
class DiscreteFactorGraph {
|
class DiscreteFactorGraph {
|
||||||
DiscreteFactorGraph();
|
DiscreteFactorGraph();
|
||||||
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
|
||||||
|
@ -277,6 +302,7 @@ class DiscreteFactorGraph {
|
||||||
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
|
||||||
void add(const gtsam::DiscreteKeys& keys, string spec);
|
void add(const gtsam::DiscreteKeys& keys, string spec);
|
||||||
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
|
||||||
|
void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec);
|
||||||
|
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
|
@ -290,25 +316,46 @@ class DiscreteFactorGraph {
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet sumProduct();
|
gtsam::DiscreteBayesNet sumProduct(
|
||||||
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteLookupDAG maxProduct();
|
gtsam::DiscreteLookupDAG maxProduct(
|
||||||
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet* eliminateSequential();
|
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||||
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
|
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||||
|
gtsam::Ordering::OrderingType type,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
|
||||||
|
gtsam::DiscreteBayesNet* eliminateSequential(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||||
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
eliminatePartialSequential(const gtsam::Ordering& ordering);
|
||||||
|
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
|
||||||
|
eliminatePartialSequential(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
|
|
||||||
gtsam::DiscreteBayesTree* eliminateMultifrontal();
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type);
|
gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
|
||||||
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
|
gtsam::Ordering::OrderingType type,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
|
const gtsam::Ordering& ordering);
|
||||||
|
gtsam::DiscreteBayesTree* eliminateMultifrontal(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||||
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
|
||||||
|
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
|
||||||
|
eliminatePartialMultifrontal(
|
||||||
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::DiscreteFactorGraph::Eliminate& function);
|
||||||
|
|
||||||
string dot(
|
string dot(
|
||||||
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
|
||||||
|
@ -328,4 +375,41 @@ class DiscreteFactorGraph {
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
|
|
||||||
|
class DiscreteEliminationTree {
|
||||||
|
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
|
||||||
|
const gtsam::VariableIndex& structure,
|
||||||
|
const gtsam::Ordering& order);
|
||||||
|
|
||||||
|
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
|
||||||
|
const gtsam::Ordering& order);
|
||||||
|
|
||||||
|
void print(
|
||||||
|
string name = "EliminationTree: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::DiscreteEliminationTree& other,
|
||||||
|
double tol = 1e-9) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
|
|
||||||
|
class DiscreteCluster {
|
||||||
|
gtsam::Ordering orderedFrontalKeys;
|
||||||
|
gtsam::DiscreteFactorGraph factors;
|
||||||
|
const gtsam::DiscreteCluster& operator[](size_t i) const;
|
||||||
|
size_t nrChildren() const;
|
||||||
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DiscreteJunctionTree {
|
||||||
|
DiscreteJunctionTree(const gtsam::DiscreteEliminationTree& eliminationTree);
|
||||||
|
void print(
|
||||||
|
string name = "JunctionTree: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
size_t nrRoots() const;
|
||||||
|
const gtsam::DiscreteCluster& operator[](size_t i) const;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -71,6 +71,19 @@ struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test char labels and int range
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
// Create a decision stump one one variable 'a' with values 10 and 20.
|
||||||
|
TEST(DecisionTree, Constructor) {
|
||||||
|
DecisionTree<char, int> tree('a', 10, 20);
|
||||||
|
|
||||||
|
// Evaluate the tree on an assignment to the variable.
|
||||||
|
EXPECT_LONGS_EQUAL(10, tree({{'a', 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(20, tree({{'a', 1}}));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test string labels and int range
|
// Test string labels and int range
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
@ -114,18 +127,47 @@ struct Ring {
|
||||||
static inline int mul(const int& a, const int& b) { return a * b; }
|
static inline int mul(const int& a, const int& b) { return a * b; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Check that creating decision trees respects key order.
|
||||||
|
TEST(DecisionTree, ConstructorOrder) {
|
||||||
|
// Create labels
|
||||||
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
const std::vector<int> ys1 = {1, 2, 3, 4};
|
||||||
|
DT tree1({{B, 2}, {A, 2}}, ys1); // faster version, as B is "higher" than A!
|
||||||
|
|
||||||
|
const std::vector<int> ys2 = {1, 3, 2, 4};
|
||||||
|
DT tree2({{A, 2}, {B, 2}}, ys2); // slower version !
|
||||||
|
|
||||||
|
// Both trees will be the same, tree is order from high to low labels.
|
||||||
|
// Choice(B)
|
||||||
|
// 0 Choice(A)
|
||||||
|
// 0 0 Leaf 1
|
||||||
|
// 0 1 Leaf 2
|
||||||
|
// 1 Choice(A)
|
||||||
|
// 1 0 Leaf 3
|
||||||
|
// 1 1 Leaf 4
|
||||||
|
|
||||||
|
EXPECT(tree2.equals(tree1));
|
||||||
|
|
||||||
|
// Check the values are as expected by calling the () operator:
|
||||||
|
EXPECT_LONGS_EQUAL(1, tree1({{A, 0}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(3, tree1({{A, 0}, {B, 1}}));
|
||||||
|
EXPECT_LONGS_EQUAL(2, tree1({{A, 1}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(4, tree1({{A, 1}, {B, 1}}));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DecisionTree, example) {
|
TEST(DecisionTree, Example) {
|
||||||
// Create labels
|
// Create labels
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
|
|
||||||
// create a value
|
// Create assignments using brace initialization:
|
||||||
Assignment<string> x00, x01, x10, x11;
|
Assignment<string> x00{{A, 0}, {B, 0}};
|
||||||
x00[A] = 0, x00[B] = 0;
|
Assignment<string> x01{{A, 0}, {B, 1}};
|
||||||
x01[A] = 0, x01[B] = 1;
|
Assignment<string> x10{{A, 1}, {B, 0}};
|
||||||
x10[A] = 1, x10[B] = 0;
|
Assignment<string> x11{{A, 1}, {B, 1}};
|
||||||
x11[A] = 1, x11[B] = 1;
|
|
||||||
|
|
||||||
// empty
|
// empty
|
||||||
DT empty;
|
DT empty;
|
||||||
|
@ -237,8 +279,7 @@ TEST(DecisionTree, ConvertValuesOnly) {
|
||||||
StringBoolTree f2(f1, bool_of_int);
|
StringBoolTree f2(f1, bool_of_int);
|
||||||
|
|
||||||
// Check a value
|
// Check a value
|
||||||
Assignment<string> x00;
|
Assignment<string> x00 {{A, 0}, {B, 0}};
|
||||||
x00["A"] = 0, x00["B"] = 0;
|
|
||||||
EXPECT(!f2(x00));
|
EXPECT(!f2(x00));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,10 +303,11 @@ TEST(DecisionTree, ConvertBoth) {
|
||||||
|
|
||||||
// Check some values
|
// Check some values
|
||||||
Assignment<Label> x00, x01, x10, x11;
|
Assignment<Label> x00, x01, x10, x11;
|
||||||
x00[X] = 0, x00[Y] = 0;
|
x00 = {{X, 0}, {Y, 0}};
|
||||||
x01[X] = 0, x01[Y] = 1;
|
x01 = {{X, 0}, {Y, 1}};
|
||||||
x10[X] = 1, x10[Y] = 0;
|
x10 = {{X, 1}, {Y, 0}};
|
||||||
x11[X] = 1, x11[Y] = 1;
|
x11 = {{X, 1}, {Y, 1}};
|
||||||
|
|
||||||
EXPECT(!f2(x00));
|
EXPECT(!f2(x00));
|
||||||
EXPECT(!f2(x01));
|
EXPECT(!f2(x01));
|
||||||
EXPECT(f2(x10));
|
EXPECT(f2(x10));
|
||||||
|
|
|
@ -27,6 +27,18 @@
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DecisionTreeFactor, ConstructorsMatch) {
|
||||||
|
// Declare two keys
|
||||||
|
DiscreteKey X(0, 2), Y(1, 3);
|
||||||
|
|
||||||
|
// Create with vector and with string
|
||||||
|
const std::vector<double> table {2, 5, 3, 6, 4, 7};
|
||||||
|
DecisionTreeFactor f1({X, Y}, table);
|
||||||
|
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
|
||||||
|
EXPECT(assert_equal(f1, f2));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DecisionTreeFactor, constructors)
|
TEST( DecisionTreeFactor, constructors)
|
||||||
{
|
{
|
||||||
|
@ -41,16 +53,13 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
EXPECT_LONGS_EQUAL(2,f2.size());
|
EXPECT_LONGS_EQUAL(2,f2.size());
|
||||||
EXPECT_LONGS_EQUAL(3,f3.size());
|
EXPECT_LONGS_EQUAL(3,f3.size());
|
||||||
|
|
||||||
DiscreteValues values;
|
DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
|
||||||
values[0] = 1; // x
|
EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
|
||||||
values[1] = 2; // y
|
EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9);
|
||||||
values[2] = 1; // z
|
EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
|
|
||||||
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
|
|
||||||
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
|
|
||||||
|
|
||||||
// Assert that error = -log(value)
|
// Assert that error = -log(value)
|
||||||
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -16,23 +16,24 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Vector.h>
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
static constexpr bool debug = false;
|
static constexpr bool debug = false;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
struct TestFixture {
|
struct TestFixture {
|
||||||
vector<DiscreteKey> keys;
|
DiscreteKeys keys;
|
||||||
|
std::vector<DiscreteValues> assignments;
|
||||||
DiscreteBayesNet bayesNet;
|
DiscreteBayesNet bayesNet;
|
||||||
boost::shared_ptr<DiscreteBayesTree> bayesTree;
|
boost::shared_ptr<DiscreteBayesTree> bayesTree;
|
||||||
|
|
||||||
|
@ -47,6 +48,9 @@ struct TestFixture {
|
||||||
keys.push_back(key_i);
|
keys.push_back(key_i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enumerate all assignments.
|
||||||
|
assignments = DiscreteValues::CartesianProduct(keys);
|
||||||
|
|
||||||
// Create thin-tree Bayesnet.
|
// Create thin-tree Bayesnet.
|
||||||
bayesNet.add(keys[14] % "1/3");
|
bayesNet.add(keys[14] % "1/3");
|
||||||
|
|
||||||
|
@ -74,9 +78,9 @@ struct TestFixture {
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
// Check that BN and BT give the same answer on all configurations
|
||||||
TEST(DiscreteBayesTree, ThinTree) {
|
TEST(DiscreteBayesTree, ThinTree) {
|
||||||
const TestFixture self;
|
TestFixture self;
|
||||||
const auto& keys = self.keys;
|
|
||||||
|
|
||||||
if (debug) {
|
if (debug) {
|
||||||
GTSAM_PRINT(self.bayesNet);
|
GTSAM_PRINT(self.bayesNet);
|
||||||
|
@ -95,47 +99,56 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto R = self.bayesTree->roots().front();
|
for (const auto& x : self.assignments) {
|
||||||
|
|
||||||
// Check whether BN and BT give the same answer on all configurations
|
|
||||||
auto allPosbValues = DiscreteValues::CartesianProduct(
|
|
||||||
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
|
|
||||||
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
|
|
||||||
keys[14]);
|
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
|
||||||
DiscreteValues x = allPosbValues[i];
|
|
||||||
double expected = self.bayesNet.evaluate(x);
|
double expected = self.bayesNet.evaluate(x);
|
||||||
double actual = self.bayesTree->evaluate(x);
|
double actual = self.bayesTree->evaluate(x);
|
||||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate all some marginals for DiscreteValues==all1
|
/* ************************************************************************* */
|
||||||
Vector marginals = Vector::Zero(15);
|
// Check calculation of separator marginals
|
||||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
TEST(DiscreteBayesTree, SeparatorMarginals) {
|
||||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
TestFixture self;
|
||||||
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
|
|
||||||
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
// Calculate some marginals for DiscreteValues==all1
|
||||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
double marginal_14 = 0, joint_8_12 = 0;
|
||||||
DiscreteValues x = allPosbValues[i];
|
for (auto& x : self.assignments) {
|
||||||
double px = self.bayesTree->evaluate(x);
|
double px = self.bayesTree->evaluate(x);
|
||||||
for (size_t i = 0; i < 15; i++)
|
|
||||||
if (x[i]) marginals[i] += px;
|
|
||||||
if (x[12] && x[14]) {
|
|
||||||
joint_12_14 += px;
|
|
||||||
if (x[9]) joint_9_12_14 += px;
|
|
||||||
if (x[8]) joint_8_12_14 += px;
|
|
||||||
}
|
|
||||||
if (x[8] && x[12]) joint_8_12 += px;
|
if (x[8] && x[12]) joint_8_12 += px;
|
||||||
if (x[2]) {
|
if (x[14]) marginal_14 += px;
|
||||||
if (x[8]) joint82 += px;
|
|
||||||
if (x[1]) joint12 += px;
|
|
||||||
}
|
}
|
||||||
if (x[4]) {
|
DiscreteValues all1 = self.assignments.back();
|
||||||
if (x[2]) joint24 += px;
|
|
||||||
if (x[5]) joint45 += px;
|
// check separator marginal P(S0)
|
||||||
if (x[6]) joint46 += px;
|
auto clique = (*self.bayesTree)[0];
|
||||||
if (x[11]) joint_4_11 += px;
|
DiscreteFactorGraph separatorMarginal0 =
|
||||||
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
|
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||||
|
|
||||||
|
// check separator marginal P(S9), should be P(14)
|
||||||
|
clique = (*self.bayesTree)[9];
|
||||||
|
DiscreteFactorGraph separatorMarginal9 =
|
||||||
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
|
DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9);
|
||||||
|
|
||||||
|
// check separator marginal of root, should be empty
|
||||||
|
clique = (*self.bayesTree)[11];
|
||||||
|
DiscreteFactorGraph separatorMarginal11 =
|
||||||
|
clique->separatorMarginal(EliminateDiscrete);
|
||||||
|
LONGS_EQUAL(0, separatorMarginal11.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check shortcuts in the tree
|
||||||
|
TEST(DiscreteBayesTree, Shortcuts) {
|
||||||
|
TestFixture self;
|
||||||
|
|
||||||
|
// Calculate some marginals for DiscreteValues==all1
|
||||||
|
double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0,
|
||||||
|
joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
|
||||||
|
for (auto& x : self.assignments) {
|
||||||
|
double px = self.bayesTree->evaluate(x);
|
||||||
if (x[11] && x[13]) {
|
if (x[11] && x[13]) {
|
||||||
joint_11_13 += px;
|
joint_11_13 += px;
|
||||||
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
if (x[8] && x[12]) joint_8_11_12_13 += px;
|
||||||
|
@ -148,32 +161,12 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
DiscreteValues all1 = allPosbValues.back();
|
DiscreteValues all1 = self.assignments.back();
|
||||||
|
|
||||||
// check separator marginal P(S0)
|
auto R = self.bayesTree->roots().front();
|
||||||
auto clique = (*self.bayesTree)[0];
|
|
||||||
DiscreteFactorGraph separatorMarginal0 =
|
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
|
||||||
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
|
||||||
|
|
||||||
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
|
|
||||||
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
|
|
||||||
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
|
|
||||||
|
|
||||||
// check separator marginal P(S9), should be P(14)
|
|
||||||
clique = (*self.bayesTree)[9];
|
|
||||||
DiscreteFactorGraph separatorMarginal9 =
|
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
|
||||||
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
|
||||||
|
|
||||||
// check separator marginal of root, should be empty
|
|
||||||
clique = (*self.bayesTree)[11];
|
|
||||||
DiscreteFactorGraph separatorMarginal11 =
|
|
||||||
clique->separatorMarginal(EliminateDiscrete);
|
|
||||||
LONGS_EQUAL(0, separatorMarginal11.size());
|
|
||||||
|
|
||||||
// check shortcut P(S9||R) to root
|
// check shortcut P(S9||R) to root
|
||||||
clique = (*self.bayesTree)[9];
|
auto clique = (*self.bayesTree)[9];
|
||||||
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
|
||||||
LONGS_EQUAL(1, shortcut.size());
|
LONGS_EQUAL(1, shortcut.size());
|
||||||
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
|
||||||
|
@ -202,15 +195,67 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
shortcut.print("shortcut:");
|
shortcut.print("shortcut:");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check all marginals
|
||||||
|
TEST(DiscreteBayesTree, MarginalFactors) {
|
||||||
|
TestFixture self;
|
||||||
|
|
||||||
|
Vector marginals = Vector::Zero(15);
|
||||||
|
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||||
|
DiscreteValues& x = self.assignments[i];
|
||||||
|
double px = self.bayesTree->evaluate(x);
|
||||||
|
for (size_t i = 0; i < 15; i++)
|
||||||
|
if (x[i]) marginals[i] += px;
|
||||||
|
}
|
||||||
|
|
||||||
// Check all marginals
|
// Check all marginals
|
||||||
DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteValues all1 = self.assignments.back();
|
||||||
for (size_t i = 0; i < 15; i++) {
|
for (size_t i = 0; i < 15; i++) {
|
||||||
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
|
||||||
double actual = (*marginalFactor)(all1);
|
double actual = (*marginalFactor)(all1);
|
||||||
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check a number of joint marginals.
|
||||||
|
TEST(DiscreteBayesTree, Joints) {
|
||||||
|
TestFixture self;
|
||||||
|
|
||||||
|
// Calculate some marginals for DiscreteValues==all1
|
||||||
|
Vector marginals = Vector::Zero(15);
|
||||||
|
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
|
||||||
|
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
|
||||||
|
for (size_t i = 0; i < self.assignments.size(); ++i) {
|
||||||
|
DiscreteValues& x = self.assignments[i];
|
||||||
|
double px = self.bayesTree->evaluate(x);
|
||||||
|
for (size_t i = 0; i < 15; i++)
|
||||||
|
if (x[i]) marginals[i] += px;
|
||||||
|
if (x[12] && x[14]) {
|
||||||
|
joint_12_14 += px;
|
||||||
|
if (x[9]) joint_9_12_14 += px;
|
||||||
|
if (x[8]) joint_8_12_14 += px;
|
||||||
|
}
|
||||||
|
if (x[2]) {
|
||||||
|
if (x[8]) joint82 += px;
|
||||||
|
if (x[1]) joint12 += px;
|
||||||
|
}
|
||||||
|
if (x[4]) {
|
||||||
|
if (x[2]) joint24 += px;
|
||||||
|
if (x[5]) joint45 += px;
|
||||||
|
if (x[6]) joint46 += px;
|
||||||
|
if (x[11]) joint_4_11 += px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// regression tests:
|
||||||
|
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
|
||||||
|
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
|
||||||
|
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
|
||||||
|
|
||||||
|
DiscreteValues all1 = self.assignments.back();
|
||||||
DiscreteBayesNet::shared_ptr actualJoint;
|
DiscreteBayesNet::shared_ptr actualJoint;
|
||||||
|
|
||||||
// Check joint P(8, 2)
|
// Check joint P(8, 2)
|
||||||
|
@ -240,8 +285,8 @@ TEST(DiscreteBayesTree, ThinTree) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteBayesTree, Dot) {
|
TEST(DiscreteBayesTree, Dot) {
|
||||||
const TestFixture self;
|
TestFixture self;
|
||||||
string actual = self.bayesTree->dot();
|
std::string actual = self.bayesTree->dot();
|
||||||
EXPECT(actual ==
|
EXPECT(actual ==
|
||||||
"digraph G{\n"
|
"digraph G{\n"
|
||||||
"0[label=\"13, 11, 6, 7\"];\n"
|
"0[label=\"13, 11, 6, 7\"];\n"
|
||||||
|
@ -268,6 +313,62 @@ TEST(DiscreteBayesTree, Dot) {
|
||||||
"}");
|
"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check that we can have a multi-frontal lookup table
|
||||||
|
TEST(DiscreteBayesTree, Lookup) {
|
||||||
|
using gtsam::symbol_shorthand::A;
|
||||||
|
using gtsam::symbol_shorthand::X;
|
||||||
|
|
||||||
|
// Make a small planning-like graph: 3 states, 2 actions
|
||||||
|
DiscreteFactorGraph graph;
|
||||||
|
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
|
||||||
|
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
|
||||||
|
|
||||||
|
// Constraint on start and goal
|
||||||
|
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
|
||||||
|
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
|
||||||
|
|
||||||
|
// Should I stay or should I go?
|
||||||
|
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
||||||
|
const double r = 10;
|
||||||
|
std::vector<double> table{
|
||||||
|
r, 0, 0, 0, r, 0, // x1 = 0
|
||||||
|
0, r, 0, 0, 0, r, // x1 = 1
|
||||||
|
0, 0, r, 0, 0, r // x1 = 2
|
||||||
|
};
|
||||||
|
graph.add(DiscreteKeys{x1, a1, x2}, table);
|
||||||
|
graph.add(DiscreteKeys{x2, a2, x3}, table);
|
||||||
|
|
||||||
|
// eliminate for MPE (maximum probable explanation).
|
||||||
|
Ordering ordering{A(2), X(3), X(1), A(1), X(2)};
|
||||||
|
auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE);
|
||||||
|
|
||||||
|
// Check that the lookup table is correct
|
||||||
|
EXPECT_LONGS_EQUAL(2, lookup->size());
|
||||||
|
auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional();
|
||||||
|
EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size());
|
||||||
|
// check that sum is 100
|
||||||
|
DiscreteValues empty;
|
||||||
|
EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9);
|
||||||
|
// And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
||||||
|
EXPECT_DOUBLES_EQUAL(100, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9);
|
||||||
|
|
||||||
|
auto lookup_a2_x3 = (*lookup)[X(3)]->conditional();
|
||||||
|
// check that the sum depends on x2 and is non-zero only for x2 \in {1,2}
|
||||||
|
auto sum_x2 = lookup_a2_x3->sum(2);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(10, (*sum_x2)({{X(2),1}}), 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(20, (*sum_x2)({{X(2),2}}), 1e-9);
|
||||||
|
EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size());
|
||||||
|
// And that the non-zero rewards are for
|
||||||
|
// x2 a2 x3 == 1 1 2
|
||||||
|
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9);
|
||||||
|
// x2 a2 x3 == 2 0 2
|
||||||
|
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9);
|
||||||
|
// x2 a2 x3 == 2 1 2
|
||||||
|
EXPECT_DOUBLES_EQUAL(10, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -106,7 +106,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||||
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||||
result = addGaussian(result, gf);
|
result = addGaussian(result, gf);
|
||||||
} else if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
|
result = gmf->add(result);
|
||||||
|
} else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
|
||||||
result = gm->add(result);
|
result = gm->add(result);
|
||||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||||
if (auto gm = hc->asMixture()) {
|
if (auto gm = hc->asMixture()) {
|
||||||
|
@ -283,17 +285,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// taking care to correct for conditional constant.
|
// taking care to correct for conditional constant.
|
||||||
|
|
||||||
// Correct for the normalization constant used up by the conditional
|
// Correct for the normalization constant used up by the conditional
|
||||||
auto correct = [&](const Result &pair) -> GaussianFactor::shared_ptr {
|
auto correct = [&](const Result &pair) {
|
||||||
const auto &factor = pair.second;
|
const auto &factor = pair.second;
|
||||||
if (!factor) return factor; // TODO(dellaert): not loving this.
|
if (!factor) return;
|
||||||
auto hf = boost::dynamic_pointer_cast<HessianFactor>(factor);
|
auto hf = boost::dynamic_pointer_cast<HessianFactor>(factor);
|
||||||
if (!hf) throw std::runtime_error("Expected HessianFactor!");
|
if (!hf) throw std::runtime_error("Expected HessianFactor!");
|
||||||
hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
|
hf->constantTerm() += 2.0 * pair.first->logNormalizationConstant();
|
||||||
return hf;
|
|
||||||
};
|
};
|
||||||
|
eliminationResults.visit(correct);
|
||||||
|
|
||||||
GaussianMixtureFactor::Factors correctedFactors(eliminationResults,
|
|
||||||
correct);
|
|
||||||
const auto mixtureFactor = boost::make_shared<GaussianMixtureFactor>(
|
const auto mixtureFactor = boost::make_shared<GaussianMixtureFactor>(
|
||||||
continuousSeparator, discreteSeparator, newFactors);
|
continuousSeparator, discreteSeparator, newFactors);
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||||
#include <gtsam/hybrid/MixtureFactor.h>
|
#include <gtsam/hybrid/MixtureFactor.h>
|
||||||
|
@ -69,6 +70,12 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
// If discrete-only: doesn't need linearization.
|
// If discrete-only: doesn't need linearization.
|
||||||
linearFG->push_back(f);
|
linearFG->push_back(f);
|
||||||
|
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
|
linearFG->push_back(gmf);
|
||||||
|
} else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
|
||||||
|
linearFG->push_back(gm);
|
||||||
|
} else if (dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||||
|
linearFG->push_back(f);
|
||||||
} else {
|
} else {
|
||||||
auto& fr = *f;
|
auto& fr = *f;
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
|
|
@ -23,6 +23,37 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
Ordering HybridSmoother::getOrdering(
|
||||||
|
const HybridGaussianFactorGraph &newFactors) {
|
||||||
|
HybridGaussianFactorGraph factors(hybridBayesNet());
|
||||||
|
factors += newFactors;
|
||||||
|
// Get all the discrete keys from the factors
|
||||||
|
KeySet allDiscrete = factors.discreteKeySet();
|
||||||
|
|
||||||
|
// Create KeyVector with continuous keys followed by discrete keys.
|
||||||
|
KeyVector newKeysDiscreteLast;
|
||||||
|
const KeySet newFactorKeys = newFactors.keys();
|
||||||
|
// Insert continuous keys first.
|
||||||
|
for (auto &k : newFactorKeys) {
|
||||||
|
if (!allDiscrete.exists(k)) {
|
||||||
|
newKeysDiscreteLast.push_back(k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert discrete keys at the end
|
||||||
|
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
||||||
|
std::back_inserter(newKeysDiscreteLast));
|
||||||
|
|
||||||
|
const VariableIndex index(newFactors);
|
||||||
|
|
||||||
|
// Get an ordering where the new keys are eliminated last
|
||||||
|
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||||
|
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
|
||||||
|
true);
|
||||||
|
return ordering;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
||||||
const Ordering &ordering,
|
const Ordering &ordering,
|
||||||
|
@ -92,7 +123,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
}
|
}
|
||||||
|
|
||||||
graph.push_back(newConditionals);
|
graph.push_back(newConditionals);
|
||||||
// newConditionals.print("\n\n\nNew Conditionals to add back");
|
|
||||||
}
|
}
|
||||||
return {graph, hybridBayesNet};
|
return {graph, hybridBayesNet};
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,6 +50,8 @@ class HybridSmoother {
|
||||||
void update(HybridGaussianFactorGraph graph, const Ordering& ordering,
|
void update(HybridGaussianFactorGraph graph, const Ordering& ordering,
|
||||||
boost::optional<size_t> maxNrLeaves = boost::none);
|
boost::optional<size_t> maxNrLeaves = boost::none);
|
||||||
|
|
||||||
|
Ordering getOrdering(const HybridGaussianFactorGraph& newFactors);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Add conditionals from previous timestep as part of liquefication.
|
* @brief Add conditionals from previous timestep as part of liquefication.
|
||||||
*
|
*
|
||||||
|
|
|
@ -35,14 +35,11 @@ class HybridValues {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
virtual class HybridFactor {
|
virtual class HybridFactor : gtsam::Factor {
|
||||||
void print(string s = "HybridFactor\n",
|
void print(string s = "HybridFactor\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const;
|
||||||
bool empty() const;
|
|
||||||
size_t size() const;
|
|
||||||
gtsam::KeyVector keys() const;
|
|
||||||
|
|
||||||
// Standard interface:
|
// Standard interface:
|
||||||
double error(const gtsam::HybridValues &values) const;
|
double error(const gtsam::HybridValues &values) const;
|
||||||
|
|
|
@ -93,6 +93,7 @@ TEST(GaussianMixtureFactor, Sum) {
|
||||||
EXPECT(actual.at(1) == f22);
|
EXPECT(actual.at(1) == f22);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
TEST(GaussianMixtureFactor, Printing) {
|
TEST(GaussianMixtureFactor, Printing) {
|
||||||
DiscreteKey m1(1, 2);
|
DiscreteKey m1(1, 2);
|
||||||
auto A1 = Matrix::Zero(2, 1);
|
auto A1 = Matrix::Zero(2, 1);
|
||||||
|
@ -136,6 +137,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
||||||
EXPECT(assert_print_equal(expected, mixtureFactor));
|
EXPECT(assert_print_equal(expected, mixtureFactor));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
TEST(GaussianMixtureFactor, GaussianMixture) {
|
TEST(GaussianMixtureFactor, GaussianMixture) {
|
||||||
KeyVector keys;
|
KeyVector keys;
|
||||||
keys.push_back(X(0));
|
keys.push_back(X(0));
|
||||||
|
|
|
@ -612,7 +612,6 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
||||||
// Check that assembleGraphTree assembles Gaussian factor graphs for each
|
// Check that assembleGraphTree assembles Gaussian factor graphs for each
|
||||||
// assignment.
|
// assignment.
|
||||||
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
|
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
|
||||||
using symbol_shorthand::Z;
|
|
||||||
const int num_measurements = 1;
|
const int num_measurements = 1;
|
||||||
auto fg = tiny::createHybridGaussianFactorGraph(
|
auto fg = tiny::createHybridGaussianFactorGraph(
|
||||||
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
|
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
|
||||||
|
@ -694,7 +693,6 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Check that eliminating tiny net with 1 measurement yields correct result.
|
// Check that eliminating tiny net with 1 measurement yields correct result.
|
||||||
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||||
using symbol_shorthand::Z;
|
|
||||||
const int num_measurements = 1;
|
const int num_measurements = 1;
|
||||||
const VectorValues measurements{{Z(0), Vector1(5.0)}};
|
const VectorValues measurements{{Z(0), Vector1(5.0)}};
|
||||||
auto bn = tiny::createHybridBayesNet(num_measurements);
|
auto bn = tiny::createHybridBayesNet(num_measurements);
|
||||||
|
@ -726,11 +724,67 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||||
EXPECT(ratioTest(bn, measurements, *posterior));
|
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Check that eliminating tiny net with 1 measurement with mode order swapped
|
||||||
|
// yields correct result.
|
||||||
|
TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
|
||||||
|
const VectorValues measurements{{Z(0), Vector1(5.0)}};
|
||||||
|
|
||||||
|
// Create mode key: 1 is low-noise, 0 is high-noise.
|
||||||
|
const DiscreteKey mode{M(0), 2};
|
||||||
|
HybridBayesNet bn;
|
||||||
|
|
||||||
|
// Create Gaussian mixture z_0 = x0 + noise for each measurement.
|
||||||
|
bn.emplace_back(new GaussianMixture(
|
||||||
|
{Z(0)}, {X(0)}, {mode},
|
||||||
|
{GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3),
|
||||||
|
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1,
|
||||||
|
0.5)}));
|
||||||
|
|
||||||
|
// Create prior on X(0).
|
||||||
|
bn.push_back(
|
||||||
|
GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5));
|
||||||
|
|
||||||
|
// Add prior on mode.
|
||||||
|
bn.emplace_back(new DiscreteConditional(mode, "1/1"));
|
||||||
|
|
||||||
|
// bn.print();
|
||||||
|
auto fg = bn.toFactorGraph(measurements);
|
||||||
|
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||||
|
|
||||||
|
// fg.print();
|
||||||
|
|
||||||
|
EXPECT(ratioTest(bn, measurements, fg));
|
||||||
|
|
||||||
|
// Create expected Bayes Net:
|
||||||
|
HybridBayesNet expectedBayesNet;
|
||||||
|
|
||||||
|
// Create Gaussian mixture on X(0).
|
||||||
|
// regression, but mean checked to be 5.0 in both cases:
|
||||||
|
const auto conditional0 = boost::make_shared<GaussianConditional>(
|
||||||
|
X(0), Vector1(10.1379), I_1x1 * 2.02759),
|
||||||
|
conditional1 = boost::make_shared<GaussianConditional>(
|
||||||
|
X(0), Vector1(14.1421), I_1x1 * 2.82843);
|
||||||
|
expectedBayesNet.emplace_back(
|
||||||
|
new GaussianMixture({X(0)}, {}, {mode}, {conditional0, conditional1}));
|
||||||
|
|
||||||
|
// Add prior on mode.
|
||||||
|
expectedBayesNet.emplace_back(new DiscreteConditional(mode, "1/1"));
|
||||||
|
|
||||||
|
// Test elimination
|
||||||
|
const auto posterior = fg.eliminateSequential();
|
||||||
|
// EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||||
|
|
||||||
|
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||||
|
|
||||||
|
// posterior->print();
|
||||||
|
// posterior->optimize().print();
|
||||||
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Check that eliminating tiny net with 2 measurements yields correct result.
|
// Check that eliminating tiny net with 2 measurements yields correct result.
|
||||||
TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
||||||
// Create factor graph with 2 measurements such that posterior mean = 5.0.
|
// Create factor graph with 2 measurements such that posterior mean = 5.0.
|
||||||
using symbol_shorthand::Z;
|
|
||||||
const int num_measurements = 2;
|
const int num_measurements = 2;
|
||||||
const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
|
const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
|
||||||
auto bn = tiny::createHybridBayesNet(num_measurements);
|
auto bn = tiny::createHybridBayesNet(num_measurements);
|
||||||
|
@ -764,7 +818,6 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
||||||
// Test eliminating tiny net with 1 mode per measurement.
|
// Test eliminating tiny net with 1 mode per measurement.
|
||||||
TEST(HybridGaussianFactorGraph, EliminateTiny22) {
|
TEST(HybridGaussianFactorGraph, EliminateTiny22) {
|
||||||
// Create factor graph with 2 measurements such that posterior mean = 5.0.
|
// Create factor graph with 2 measurements such that posterior mean = 5.0.
|
||||||
using symbol_shorthand::Z;
|
|
||||||
const int num_measurements = 2;
|
const int num_measurements = 2;
|
||||||
const bool manyModes = true;
|
const bool manyModes = true;
|
||||||
|
|
||||||
|
|
|
@ -140,9 +140,15 @@ namespace gtsam {
|
||||||
/** Access the conditional */
|
/** Access the conditional */
|
||||||
const sharedConditional& conditional() const { return conditional_; }
|
const sharedConditional& conditional() const { return conditional_; }
|
||||||
|
|
||||||
/** is this the root of a Bayes tree ? */
|
/// Return true if this clique is the root of a Bayes tree.
|
||||||
inline bool isRoot() const { return parent_.expired(); }
|
inline bool isRoot() const { return parent_.expired(); }
|
||||||
|
|
||||||
|
/// Return the number of children.
|
||||||
|
size_t nrChildren() const { return children.size(); }
|
||||||
|
|
||||||
|
/// Return the child at index i.
|
||||||
|
const derived_ptr operator[](size_t i) const { return children[i]; }
|
||||||
|
|
||||||
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
|
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
|
||||||
size_t treeSize() const;
|
size_t treeSize() const;
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ class ClusterTree {
|
||||||
virtual ~Cluster() {}
|
virtual ~Cluster() {}
|
||||||
|
|
||||||
const Cluster& operator[](size_t i) const {
|
const Cluster& operator[](size_t i) const {
|
||||||
return *(children[i]);
|
return *(children.at(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct from factors associated with a single key
|
/// Construct from factors associated with a single key
|
||||||
|
@ -161,7 +161,7 @@ class ClusterTree {
|
||||||
}
|
}
|
||||||
|
|
||||||
const Cluster& operator[](size_t i) const {
|
const Cluster& operator[](size_t i) const {
|
||||||
return *(roots_[i]);
|
return *(roots_.at(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
|
@ -52,12 +52,12 @@ namespace gtsam {
|
||||||
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
|
* algorithms. Any factor graph holding eliminateable factors can derive from this class to
|
||||||
* expose functions for computing marginals, conditional marginals, doing multifrontal and
|
* expose functions for computing marginals, conditional marginals, doing multifrontal and
|
||||||
* sequential elimination, etc. */
|
* sequential elimination, etc. */
|
||||||
template<class FACTORGRAPH>
|
template<class FACTOR_GRAPH>
|
||||||
class EliminateableFactorGraph
|
class EliminateableFactorGraph
|
||||||
{
|
{
|
||||||
private:
|
private:
|
||||||
typedef EliminateableFactorGraph<FACTORGRAPH> This; ///< Typedef to this class.
|
typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
|
||||||
typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type
|
typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
|
||||||
// Base factor type stored in this graph (private because derived classes will get this from
|
// Base factor type stored in this graph (private because derived classes will get this from
|
||||||
// their FactorGraph base class)
|
// their FactorGraph base class)
|
||||||
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
|
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;
|
||||||
|
@ -139,7 +139,7 @@ namespace gtsam {
|
||||||
OptionalVariableIndex variableIndex = boost::none) const;
|
OptionalVariableIndex variableIndex = boost::none) const;
|
||||||
|
|
||||||
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
|
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
|
||||||
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on
|
* provided, the ordering will be computed using either COLAMD or METIS, depending on
|
||||||
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
|
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
|
||||||
*
|
*
|
||||||
* <b> Example - Full Cholesky elimination in COLAMD order: </b>
|
* <b> Example - Full Cholesky elimination in COLAMD order: </b>
|
||||||
|
@ -160,7 +160,7 @@ namespace gtsam {
|
||||||
OptionalVariableIndex variableIndex = boost::none) const;
|
OptionalVariableIndex variableIndex = boost::none) const;
|
||||||
|
|
||||||
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
|
/** Do multifrontal elimination of all variables to produce a Bayes tree. If an ordering is not
|
||||||
* provided, the ordering will be computed using either COLAMD or METIS, dependeing on
|
* provided, the ordering will be computed using either COLAMD or METIS, depending on
|
||||||
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
|
* the parameter orderingType (Ordering::COLAMD or Ordering::METIS)
|
||||||
*
|
*
|
||||||
* <b> Example - Full QR elimination in specified order:
|
* <b> Example - Full QR elimination in specified order:
|
||||||
|
|
|
@ -104,6 +104,7 @@ class Ordering {
|
||||||
// Standard Constructors and Named Constructors
|
// Standard Constructors and Named Constructors
|
||||||
Ordering();
|
Ordering();
|
||||||
Ordering(const gtsam::Ordering& other);
|
Ordering(const gtsam::Ordering& other);
|
||||||
|
Ordering(const std::vector<size_t>& keys);
|
||||||
|
|
||||||
template <
|
template <
|
||||||
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
|
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
|
||||||
|
@ -147,7 +148,7 @@ class Ordering {
|
||||||
|
|
||||||
// Standard interface
|
// Standard interface
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
size_t at(size_t key) const;
|
size_t at(size_t i) const;
|
||||||
void push_back(size_t key);
|
void push_back(size_t key);
|
||||||
|
|
||||||
// enabling serialization functionality
|
// enabling serialization functionality
|
||||||
|
@ -197,4 +198,15 @@ class VariableIndex {
|
||||||
size_t nEntries() const;
|
size_t nEntries() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include <gtsam/inference/Factor.h>
|
||||||
|
virtual class Factor {
|
||||||
|
void print(string s = "Factor\n", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
void printKeys(string s = "") const;
|
||||||
|
bool equals(const gtsam::Factor& other, double tol = 1e-9) const;
|
||||||
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeyVector keys() const;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -261,8 +261,7 @@ class VectorValues {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/GaussianFactor.h>
|
#include <gtsam/linear/GaussianFactor.h>
|
||||||
virtual class GaussianFactor {
|
virtual class GaussianFactor : gtsam::Factor {
|
||||||
gtsam::KeyVector keys() const;
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
||||||
|
@ -273,8 +272,6 @@ virtual class GaussianFactor {
|
||||||
Matrix information() const;
|
Matrix information() const;
|
||||||
Matrix augmentedJacobian() const;
|
Matrix augmentedJacobian() const;
|
||||||
pair<Matrix, Vector> jacobian() const;
|
pair<Matrix, Vector> jacobian() const;
|
||||||
size_t size() const;
|
|
||||||
bool empty() const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
|
@ -301,10 +298,7 @@ virtual class JacobianFactor : gtsam::GaussianFactor {
|
||||||
//Testable
|
//Testable
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void printKeys(string s) const;
|
|
||||||
gtsam::KeyVector& keys() const;
|
|
||||||
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
||||||
size_t size() const;
|
|
||||||
Vector unweighted_error(const gtsam::VectorValues& c) const;
|
Vector unweighted_error(const gtsam::VectorValues& c) const;
|
||||||
Vector error_vector(const gtsam::VectorValues& c) const;
|
Vector error_vector(const gtsam::VectorValues& c) const;
|
||||||
double error(const gtsam::VectorValues& c) const;
|
double error(const gtsam::VectorValues& c) const;
|
||||||
|
@ -346,10 +340,8 @@ virtual class HessianFactor : gtsam::GaussianFactor {
|
||||||
HessianFactor(const gtsam::GaussianFactorGraph& factors);
|
HessianFactor(const gtsam::GaussianFactorGraph& factors);
|
||||||
|
|
||||||
//Testable
|
//Testable
|
||||||
size_t size() const;
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void printKeys(string s) const;
|
|
||||||
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
bool equals(const gtsam::GaussianFactor& lf, double tol) const;
|
||||||
double error(const gtsam::VectorValues& c) const;
|
double error(const gtsam::VectorValues& c) const;
|
||||||
|
|
||||||
|
|
|
@ -110,13 +110,10 @@ class NonlinearFactorGraph {
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||||
virtual class NonlinearFactor {
|
virtual class NonlinearFactor : gtsam::Factor {
|
||||||
// Factor base class
|
// Factor base class
|
||||||
size_t size() const;
|
|
||||||
gtsam::KeyVector keys() const;
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void printKeys(string s) const;
|
|
||||||
// NonlinearFactor
|
// NonlinearFactor
|
||||||
bool equals(const gtsam::NonlinearFactor& other, double tol) const;
|
bool equals(const gtsam::NonlinearFactor& other, double tol) const;
|
||||||
double error(const gtsam::Values& c) const;
|
double error(const gtsam::Values& c) const;
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
#include <iomanip>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -38,7 +39,8 @@ static DSFMapIndexPair generateDSF(const MatchIndicesMap& matches) {
|
||||||
// Image pair is (i1,i2).
|
// Image pair is (i1,i2).
|
||||||
size_t i1 = pair_indices.first;
|
size_t i1 = pair_indices.first;
|
||||||
size_t i2 = pair_indices.second;
|
size_t i2 = pair_indices.second;
|
||||||
for (size_t k = 0; k < corr_indices.rows(); k++) {
|
size_t m = static_cast<size_t>(corr_indices.rows());
|
||||||
|
for (size_t k = 0; k < m; k++) {
|
||||||
// Measurement indices are found in a single matrix row, as (k1,k2).
|
// Measurement indices are found in a single matrix row, as (k1,k2).
|
||||||
size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1);
|
size_t k1 = corr_indices(k, 0), k2 = corr_indices(k, 1);
|
||||||
// Unique key for DSF is (i,k), representing keypoint index in an image.
|
// Unique key for DSF is (i,k), representing keypoint index in an image.
|
||||||
|
@ -128,7 +130,7 @@ std::vector<SfmTrack2d> tracksFromPairwiseMatches(
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(johnwlambert): return the Transitivity failure percentage here.
|
// TODO(johnwlambert): return the Transitivity failure percentage here.
|
||||||
return tracks2d;
|
return validTracks;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsfm
|
} // namespace gtsfm
|
||||||
|
|
|
@ -20,9 +20,10 @@
|
||||||
#include <gtsam/base/DSFMap.h>
|
#include <gtsam/base/DSFMap.h>
|
||||||
#include <gtsam/sfm/SfmTrack.h>
|
#include <gtsam/sfm/SfmTrack.h>
|
||||||
|
|
||||||
|
#include <boost/optional.hpp>
|
||||||
|
|
||||||
#include <Eigen/Core>
|
#include <Eigen/Core>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <optional>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
|
@ -65,4 +65,6 @@ namespace gtsam {
|
||||||
SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree);
|
SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// typedef for wrapper:
|
||||||
|
using SymbolicCluster = SymbolicJunctionTree::Cluster;
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicFactor.h>
|
#include <gtsam/symbolic/SymbolicFactor.h>
|
||||||
virtual class SymbolicFactor {
|
virtual class SymbolicFactor : gtsam::Factor {
|
||||||
// Standard Constructors and Named Constructors
|
// Standard Constructors and Named Constructors
|
||||||
SymbolicFactor(const gtsam::SymbolicFactor& f);
|
SymbolicFactor(const gtsam::SymbolicFactor& f);
|
||||||
SymbolicFactor();
|
SymbolicFactor();
|
||||||
|
@ -18,12 +18,10 @@ virtual class SymbolicFactor {
|
||||||
static gtsam::SymbolicFactor FromKeys(const gtsam::KeyVector& js);
|
static gtsam::SymbolicFactor FromKeys(const gtsam::KeyVector& js);
|
||||||
|
|
||||||
// From Factor
|
// From Factor
|
||||||
size_t size() const;
|
|
||||||
void print(string s = "SymbolicFactor",
|
void print(string s = "SymbolicFactor",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::SymbolicFactor& other, double tol) const;
|
bool equals(const gtsam::SymbolicFactor& other, double tol) const;
|
||||||
gtsam::KeyVector keys();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicFactorGraph.h>
|
#include <gtsam/symbolic/SymbolicFactorGraph.h>
|
||||||
|
@ -139,7 +137,60 @@ class SymbolicBayesNet {
|
||||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include <gtsam/symbolic/SymbolicEliminationTree.h>
|
||||||
|
|
||||||
|
class SymbolicEliminationTree {
|
||||||
|
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
|
||||||
|
const gtsam::VariableIndex& structure,
|
||||||
|
const gtsam::Ordering& order);
|
||||||
|
|
||||||
|
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
|
||||||
|
const gtsam::Ordering& order);
|
||||||
|
|
||||||
|
void print(
|
||||||
|
string name = "EliminationTree: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::SymbolicEliminationTree& other,
|
||||||
|
double tol = 1e-9) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/symbolic/SymbolicJunctionTree.h>
|
||||||
|
|
||||||
|
class SymbolicCluster {
|
||||||
|
gtsam::Ordering orderedFrontalKeys;
|
||||||
|
gtsam::SymbolicFactorGraph factors;
|
||||||
|
const gtsam::SymbolicCluster& operator[](size_t i) const;
|
||||||
|
size_t nrChildren() const;
|
||||||
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SymbolicJunctionTree {
|
||||||
|
SymbolicJunctionTree(const gtsam::SymbolicEliminationTree& eliminationTree);
|
||||||
|
void print(
|
||||||
|
string name = "JunctionTree: ",
|
||||||
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
|
size_t nrRoots() const;
|
||||||
|
const gtsam::SymbolicCluster& operator[](size_t i) const;
|
||||||
|
};
|
||||||
|
|
||||||
#include <gtsam/symbolic/SymbolicBayesTree.h>
|
#include <gtsam/symbolic/SymbolicBayesTree.h>
|
||||||
|
|
||||||
|
class SymbolicBayesTreeClique {
|
||||||
|
SymbolicBayesTreeClique();
|
||||||
|
SymbolicBayesTreeClique(const gtsam::SymbolicConditional* conditional);
|
||||||
|
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
|
||||||
|
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter);
|
||||||
|
const gtsam::SymbolicConditional* conditional() const;
|
||||||
|
bool isRoot() const;
|
||||||
|
gtsam::SymbolicBayesTreeClique* parent() const;
|
||||||
|
size_t treeSize() const;
|
||||||
|
size_t numCachedSeparatorMarginals() const;
|
||||||
|
void deleteCachedShortcuts();
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class SymbolicBayesTree {
|
class SymbolicBayesTree {
|
||||||
// Constructors
|
// Constructors
|
||||||
SymbolicBayesTree();
|
SymbolicBayesTree();
|
||||||
|
@ -151,9 +202,14 @@ class SymbolicBayesTree {
|
||||||
bool equals(const gtsam::SymbolicBayesTree& other, double tol) const;
|
bool equals(const gtsam::SymbolicBayesTree& other, double tol) const;
|
||||||
|
|
||||||
// Standard Interface
|
// Standard Interface
|
||||||
// size_t findParentClique(const gtsam::IndexVector& parents) const;
|
bool empty() const;
|
||||||
size_t size();
|
size_t size() const;
|
||||||
void saveGraph(string s) const;
|
|
||||||
|
const gtsam::SymbolicBayesTreeClique* operator[](size_t j) const;
|
||||||
|
|
||||||
|
void saveGraph(string s,
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
void clear();
|
void clear();
|
||||||
void deleteCachedShortcuts();
|
void deleteCachedShortcuts();
|
||||||
size_t numCachedSeparatorMarginals() const;
|
size_t numCachedSeparatorMarginals() const;
|
||||||
|
@ -161,28 +217,9 @@ class SymbolicBayesTree {
|
||||||
gtsam::SymbolicConditional* marginalFactor(size_t key) const;
|
gtsam::SymbolicConditional* marginalFactor(size_t key) const;
|
||||||
gtsam::SymbolicFactorGraph* joint(size_t key1, size_t key2) const;
|
gtsam::SymbolicFactorGraph* joint(size_t key1, size_t key2) const;
|
||||||
gtsam::SymbolicBayesNet* jointBayesNet(size_t key1, size_t key2) const;
|
gtsam::SymbolicBayesNet* jointBayesNet(size_t key1, size_t key2) const;
|
||||||
};
|
|
||||||
|
|
||||||
class SymbolicBayesTreeClique {
|
string dot(const gtsam::KeyFormatter& keyFormatter =
|
||||||
SymbolicBayesTreeClique();
|
|
||||||
// SymbolicBayesTreeClique(gtsam::sharedConditional* conditional);
|
|
||||||
|
|
||||||
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
|
|
||||||
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
size_t numCachedSeparatorMarginals() const;
|
|
||||||
// gtsam::sharedConditional* conditional() const;
|
|
||||||
bool isRoot() const;
|
|
||||||
size_t treeSize() const;
|
|
||||||
gtsam::SymbolicBayesTreeClique* parent() const;
|
|
||||||
|
|
||||||
// // TODO: need wrapped versions graphs, BayesNet
|
|
||||||
// BayesNet<ConditionalType> shortcut(derived_ptr root, Eliminate function)
|
|
||||||
// const; FactorGraph<FactorType> marginal(derived_ptr root, Eliminate
|
|
||||||
// function) const; FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr
|
|
||||||
// root, Eliminate function) const;
|
|
||||||
//
|
|
||||||
void deleteCachedShortcuts();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -25,7 +25,13 @@ class TestDecisionTreeFactor(GtsamTestCase):
|
||||||
self.B = (5, 2)
|
self.B = (5, 2)
|
||||||
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
|
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
|
||||||
|
|
||||||
|
def test_from_floats(self):
|
||||||
|
"""Test whether we can construct a factor from floats."""
|
||||||
|
actual = DecisionTreeFactor([self.A, self.B], [1., 2., 3., 4., 5., 6.])
|
||||||
|
self.gtsamAssertEquals(actual, self.factor)
|
||||||
|
|
||||||
def test_enumerate(self):
|
def test_enumerate(self):
|
||||||
|
"""Test whether we can enumerate the factor."""
|
||||||
actual = self.factor.enumerate()
|
actual = self.factor.enumerate()
|
||||||
_, values = zip(*actual)
|
_, values = zip(*actual)
|
||||||
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
|
||||||
|
|
|
@ -13,10 +13,15 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
import numpy as np
|
||||||
DiscreteConditional, DiscreteFactorGraph, Ordering)
|
from gtsam.symbol_shorthand import A, X
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
import gtsam
|
||||||
|
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
||||||
|
DiscreteConditional, DiscreteFactorGraph,
|
||||||
|
DiscreteValues, Ordering)
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
"""Tests for Discrete Bayes Nets."""
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
@ -65,15 +70,105 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
# bayesTree[key].printSignature()
|
# bayesTree[key].printSignature()
|
||||||
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
|
# bayesTree.saveGraph("test_DiscreteBayesTree.dot")
|
||||||
|
|
||||||
self.assertFalse(bayesTree.empty())
|
|
||||||
self.assertEqual(12, bayesTree.size())
|
|
||||||
|
|
||||||
# The root is P( 8 12 14), we can retrieve it by key:
|
# The root is P( 8 12 14), we can retrieve it by key:
|
||||||
root = bayesTree[8]
|
root = bayesTree[8]
|
||||||
self.assertIsInstance(root, DiscreteBayesTreeClique)
|
self.assertIsInstance(root, DiscreteBayesTreeClique)
|
||||||
self.assertTrue(root.isRoot())
|
self.assertTrue(root.isRoot())
|
||||||
self.assertIsInstance(root.conditional(), DiscreteConditional)
|
self.assertIsInstance(root.conditional(), DiscreteConditional)
|
||||||
|
|
||||||
|
# Test all methods in DiscreteBayesTree
|
||||||
|
self.gtsamAssertEquals(bayesTree, bayesTree)
|
||||||
|
|
||||||
|
# Check value at 0
|
||||||
|
zero_values = DiscreteValues()
|
||||||
|
for j in range(15):
|
||||||
|
zero_values[j] = 0
|
||||||
|
value_at_zeros = bayesTree.evaluate(zero_values)
|
||||||
|
self.assertAlmostEqual(value_at_zeros, 0.0)
|
||||||
|
|
||||||
|
# Check value at max
|
||||||
|
values_star = factorGraph.optimize()
|
||||||
|
max_value = bayesTree.evaluate(values_star)
|
||||||
|
self.assertAlmostEqual(max_value, 0.002548)
|
||||||
|
|
||||||
|
# Check operator sugar
|
||||||
|
max_value = bayesTree(values_star)
|
||||||
|
self.assertAlmostEqual(max_value, 0.002548)
|
||||||
|
|
||||||
|
self.assertFalse(bayesTree.empty())
|
||||||
|
self.assertEqual(12, bayesTree.size())
|
||||||
|
|
||||||
|
@unittest.skip("TODO: segfaults on gcc 7 and gcc 9")
|
||||||
|
def test_discrete_bayes_tree_lookup(self):
|
||||||
|
"""Check that we can have a multi-frontal lookup table."""
|
||||||
|
# Make a small planning-like graph: 3 states, 2 actions
|
||||||
|
graph = DiscreteFactorGraph()
|
||||||
|
x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
|
||||||
|
a1, a2 = (A(1), 2), (A(2), 2)
|
||||||
|
|
||||||
|
# Constraint on start and goal
|
||||||
|
graph.add([x1], np.array([1, 0, 0]))
|
||||||
|
graph.add([x3], np.array([0, 0, 1]))
|
||||||
|
|
||||||
|
# Should I stay or should I go?
|
||||||
|
# "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
|
||||||
|
r = 10
|
||||||
|
table = np.array([
|
||||||
|
r, 0, 0, 0, r, 0, # x1 = 0
|
||||||
|
0, r, 0, 0, 0, r, # x1 = 1
|
||||||
|
0, 0, r, 0, 0, r # x1 = 2
|
||||||
|
])
|
||||||
|
graph.add([x1, a1, x2], table)
|
||||||
|
graph.add([x2, a2, x3], table)
|
||||||
|
|
||||||
|
# print(graph) will give:
|
||||||
|
# size: 4
|
||||||
|
# factor 0: f[ (x1,3), ] ...
|
||||||
|
# factor 1: f[ (x3,3), ] ...
|
||||||
|
# factor 2: f[ (x1,3), (a1,2), (x2,3), ] ...
|
||||||
|
# factor 3: f[ (x2,3), (a2,2), (x3,3), ] ...
|
||||||
|
|
||||||
|
# Eliminate for MPE (maximum probable explanation).
|
||||||
|
ordering = Ordering(keys=[A(2), X(3), X(1), A(1), X(2)])
|
||||||
|
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
|
||||||
|
|
||||||
|
# print(lookup) will give:
|
||||||
|
# DiscreteBayesTree
|
||||||
|
# : cliques: 2, variables: 5
|
||||||
|
# - g( x1 a1 x2 ): ...
|
||||||
|
# | - g( a2 x3 ; x2 ): ...
|
||||||
|
|
||||||
|
# Check that the lookup table is correct
|
||||||
|
assert lookup.size() == 2
|
||||||
|
lookup_x1_a1_x2 = lookup[X(1)].conditional()
|
||||||
|
assert lookup_x1_a1_x2.nrFrontals() == 3
|
||||||
|
# Check that sum is 100
|
||||||
|
empty = gtsam.DiscreteValues()
|
||||||
|
self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 100)
|
||||||
|
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[X(1)] = 0
|
||||||
|
values[A(1)] = 1
|
||||||
|
values[X(2)] = 1
|
||||||
|
self.assertAlmostEqual(lookup_x1_a1_x2(values), 100)
|
||||||
|
|
||||||
|
lookup_a2_x3 = lookup[X(3)].conditional()
|
||||||
|
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
|
||||||
|
sum_x2 = lookup_a2_x3.sum(2)
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[X(2)] = 0
|
||||||
|
self.assertAlmostEqual(sum_x2(values), 0)
|
||||||
|
values[X(2)] = 1
|
||||||
|
self.assertAlmostEqual(sum_x2(values), 10)
|
||||||
|
values[X(2)] = 2
|
||||||
|
self.assertAlmostEqual(sum_x2(values), 20)
|
||||||
|
assert lookup_a2_x3.nrFrontals() == 2
|
||||||
|
# And that the non-zero rewards are for x2 a2 x3 == 1 1 2
|
||||||
|
values = DiscreteValues()
|
||||||
|
values[X(2)] = 1
|
||||||
|
values[A(2)] = 1
|
||||||
|
values[X(3)] = 2
|
||||||
|
self.assertAlmostEqual(lookup_a2_x3(values), 10)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -4,18 +4,44 @@ Authors: John Lambert
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import gtsam
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
|
|
||||||
SfmMeasurementVector, SfmTrack2d)
|
|
||||||
from gtsam.gtsfm import Keypoints
|
from gtsam.gtsfm import Keypoints
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
import gtsam
|
||||||
|
from gtsam import (IndexPair, KeypointsVector, MatchIndicesMap, Point2,
|
||||||
|
SfmMeasurementVector, SfmTrack2d)
|
||||||
|
|
||||||
|
|
||||||
class TestDsfTrackGenerator(GtsamTestCase):
|
class TestDsfTrackGenerator(GtsamTestCase):
|
||||||
"""Tests for DsfTrackGenerator."""
|
"""Tests for DsfTrackGenerator."""
|
||||||
|
|
||||||
|
def test_generate_tracks_from_pairwise_matches_nontransitive(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
"""Tests DSF for non-transitive matches.
|
||||||
|
|
||||||
|
Test will result in no tracks since nontransitive tracks are naively
|
||||||
|
discarded by DSF.
|
||||||
|
"""
|
||||||
|
keypoints = get_dummy_keypoints_list()
|
||||||
|
nontransitive_matches = get_nontransitive_matches()
|
||||||
|
|
||||||
|
# For each image pair (i1,i2), we provide a (K,2) matrix
|
||||||
|
# of corresponding keypoint indices (k1,k2).
|
||||||
|
matches = MatchIndicesMap()
|
||||||
|
for (i1, i2), correspondences in nontransitive_matches.items():
|
||||||
|
matches[IndexPair(i1, i2)] = correspondences
|
||||||
|
|
||||||
|
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
|
||||||
|
matches,
|
||||||
|
keypoints,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(tracks), 0, "Tracks not filtered correctly")
|
||||||
|
|
||||||
def test_track_generation(self) -> None:
|
def test_track_generation(self) -> None:
|
||||||
"""Ensures that DSF generates three tracks from measurements
|
"""Ensures that DSF generates three tracks from measurements
|
||||||
in 3 images (H=200,W=400)."""
|
in 3 images (H=200,W=400)."""
|
||||||
|
@ -23,20 +49,20 @@ class TestDsfTrackGenerator(GtsamTestCase):
|
||||||
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
|
kps_i1 = Keypoints(np.array([[50.0, 60], [70, 80], [90, 100]]))
|
||||||
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))
|
kps_i2 = Keypoints(np.array([[110.0, 120], [130, 140]]))
|
||||||
|
|
||||||
keypoints_list = KeypointsVector()
|
keypoints = KeypointsVector()
|
||||||
keypoints_list.append(kps_i0)
|
keypoints.append(kps_i0)
|
||||||
keypoints_list.append(kps_i1)
|
keypoints.append(kps_i1)
|
||||||
keypoints_list.append(kps_i2)
|
keypoints.append(kps_i2)
|
||||||
|
|
||||||
# For each image pair (i1,i2), we provide a (K,2) matrix
|
# For each image pair (i1,i2), we provide a (K,2) matrix
|
||||||
# of corresponding image indices (k1,k2).
|
# of corresponding image indices (k1,k2).
|
||||||
matches_dict = MatchIndicesMap()
|
matches = MatchIndicesMap()
|
||||||
matches_dict[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
|
matches[IndexPair(0, 1)] = np.array([[0, 0], [1, 1]])
|
||||||
matches_dict[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
|
matches[IndexPair(1, 2)] = np.array([[2, 0], [1, 1]])
|
||||||
|
|
||||||
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
|
tracks = gtsam.gtsfm.tracksFromPairwiseMatches(
|
||||||
matches_dict,
|
matches,
|
||||||
keypoints_list,
|
keypoints,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
)
|
)
|
||||||
assert len(tracks) == 3
|
assert len(tracks) == 3
|
||||||
|
@ -93,5 +119,71 @@ class TestSfmTrack2d(GtsamTestCase):
|
||||||
assert track.numberMeasurements() == 1
|
assert track.numberMeasurements() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def get_dummy_keypoints_list() -> KeypointsVector:
|
||||||
|
"""Generate a list of dummy keypoints for testing."""
|
||||||
|
img1_kp_coords = np.array([[1, 1], [2, 2], [3, 3.]])
|
||||||
|
img2_kp_coords = np.array(
|
||||||
|
[
|
||||||
|
[1, 1.],
|
||||||
|
[2, 2],
|
||||||
|
[3, 3],
|
||||||
|
[4, 4],
|
||||||
|
[5, 5],
|
||||||
|
[6, 6],
|
||||||
|
[7, 7],
|
||||||
|
[8, 8],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
img3_kp_coords = np.array(
|
||||||
|
[
|
||||||
|
[1, 1.],
|
||||||
|
[2, 2],
|
||||||
|
[3, 3],
|
||||||
|
[4, 4],
|
||||||
|
[5, 5],
|
||||||
|
[6, 6],
|
||||||
|
[7, 7],
|
||||||
|
[8, 8],
|
||||||
|
[9, 9],
|
||||||
|
[10, 10],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
img4_kp_coords = np.array(
|
||||||
|
[
|
||||||
|
[1, 1.],
|
||||||
|
[2, 2],
|
||||||
|
[3, 3],
|
||||||
|
[4, 4],
|
||||||
|
[5, 5],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
keypoints = KeypointsVector()
|
||||||
|
keypoints.append(Keypoints(coordinates=img1_kp_coords))
|
||||||
|
keypoints.append(Keypoints(coordinates=img2_kp_coords))
|
||||||
|
keypoints.append(Keypoints(coordinates=img3_kp_coords))
|
||||||
|
keypoints.append(Keypoints(coordinates=img4_kp_coords))
|
||||||
|
return keypoints
|
||||||
|
|
||||||
|
|
||||||
|
def get_nontransitive_matches() -> Dict[Tuple[int, int], np.ndarray]:
|
||||||
|
"""Set up correspondences for each (i1,i2) pair that violates transitivity.
|
||||||
|
|
||||||
|
(i=0, k=0) (i=0, k=1)
|
||||||
|
| \\ |
|
||||||
|
| \\ |
|
||||||
|
(i=1, k=2)--(i=2,k=3)--(i=3, k=4)
|
||||||
|
|
||||||
|
Transitivity is violated due to the match between frames 0 and 3.
|
||||||
|
"""
|
||||||
|
nontransitive_matches = {
|
||||||
|
(0, 1): np.array([[0, 2]]),
|
||||||
|
(1, 2): np.array([[2, 3]]),
|
||||||
|
(0, 2): np.array([[0, 3]]),
|
||||||
|
(0, 3): np.array([[1, 4]]),
|
||||||
|
(2, 3): np.array([[3, 4]]),
|
||||||
|
}
|
||||||
|
return nontransitive_matches
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue