Merge branch 'develop' into fix/windows-tests

release/4.3a0
Varun Agrawal 2023-07-27 12:06:12 -04:00
commit e4ff39cd42
61 changed files with 1340 additions and 424 deletions

View File

@ -105,34 +105,52 @@ jobs:
cmake -G Ninja -B build -S . -DGTSAM_BUILD_EXAMPLES_ALWAYS=OFF -DBOOST_ROOT="${env:BOOST_ROOT}" -DBOOST_INCLUDEDIR="${env:BOOST_ROOT}\boost\include" -DBOOST_LIBRARYDIR="${env:BOOST_ROOT}\lib" cmake -G Ninja -B build -S . -DGTSAM_BUILD_EXAMPLES_ALWAYS=OFF -DBOOST_ROOT="${env:BOOST_ROOT}" -DBOOST_INCLUDEDIR="${env:BOOST_ROOT}\boost\include" -DBOOST_LIBRARYDIR="${env:BOOST_ROOT}\lib"
- name: Build - name: Build
shell: bash
run: | run: |
# Since Visual Studio is a multi-generator, we need to use --config # Since Visual Studio is a multi-generator, we need to use --config
# https://stackoverflow.com/a/24470998/1236990 # https://stackoverflow.com/a/24470998/1236990
cmake --build build -j4 --config ${{ matrix.build_type }} --target gtsam cmake --build build -j4 --config ${{ matrix.build_type }} --target gtsam
cmake --build build -j4 --config ${{ matrix.build_type }} --target gtsam_unstable cmake --build build -j4 --config ${{ matrix.build_type }} --target gtsam_unstable
cmake --build build -j4 --config ${{ matrix.build_type }} --target wrap
# Target doesn't exist
# cmake --build build -j4 --config ${{ matrix.build_type }} --target wrap
- name: Test
shell: bash
run: |
# Run GTSAM tests # Run GTSAM tests
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.base cmake --build build -j4 --config ${{ matrix.build_type }} --target check.base
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.basis cmake --build build -j4 --config ${{ matrix.build_type }} --target check.basis
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.discrete cmake --build build -j4 --config ${{ matrix.build_type }} --target check.discrete
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.geometry # Compilation error
# cmake --build build -j4 --config ${{ matrix.build_type }} --target check.geometry
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.inference cmake --build build -j4 --config ${{ matrix.build_type }} --target check.inference
# Compile. Fail with exception
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.linear cmake --build build -j4 --config ${{ matrix.build_type }} --target check.linear
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.navigation cmake --build build -j4 --config ${{ matrix.build_type }} --target check.navigation
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.sam cmake --build build -j4 --config ${{ matrix.build_type }} --target check.sam
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.sfm cmake --build build -j4 --config ${{ matrix.build_type }} --target check.sfm
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.symbolic cmake --build build -j4 --config ${{ matrix.build_type }} --target check.symbolic
# Compile. Fail with exception
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.hybrid cmake --build build -j4 --config ${{ matrix.build_type }} --target check.hybrid
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.nonlinear # Compile. Fail with exception
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.slam # cmake --build build -j4 --config ${{ matrix.build_type }} --target check.nonlinear
# Compilation error
# cmake --build build -j4 --config ${{ matrix.build_type }} --target check.slam
# Run GTSAM_UNSTABLE tests # Run GTSAM_UNSTABLE tests
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.base_unstable cmake --build build -j4 --config ${{ matrix.build_type }} --target check.base_unstable
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.geometry_unstable # Compile. Fail with exception
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.linear_unstable # cmake --build build -j4 --config ${{ matrix.build_type }} --target check.geometry_unstable
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.discrete_unstable # Compile. Fail with exception
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.dynamics_unstable # cmake --build build -j4 --config ${{ matrix.build_type }} --target check.linear_unstable
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.nonlinear_unstable # Compile. Fail with exception
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.slam_unstable # cmake --build build -j4 --config ${{ matrix.build_type }} --target check.discrete_unstable
cmake --build build -j4 --config ${{ matrix.build_type }} --target check.partition # Compile. Fail with exception
# cmake --build build -j4 --config ${{ matrix.build_type }} --target check.dynamics_unstable
# Compile. Fail with exception
# cmake --build build -j4 --config ${{ matrix.build_type }} --target check.nonlinear_unstable
# Compilation error
# cmake --build build -j4 --config ${{ matrix.build_type }} --target check.slam_unstable
# Compilation error
# cmake --build build -j4 --config ${{ matrix.build_type }} --target check.partition

View File

@ -32,6 +32,14 @@ set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH})
############################################################################### ###############################################################################
# Gather information, perform checks, set defaults # Gather information, perform checks, set defaults
if(MSVC)
set(MSVC_LINKER_FLAGS "/FORCE:MULTIPLE")
set(CMAKE_EXE_LINKER_FLAGS ${MSVC_LINKER_FLAGS})
set(CMAKE_MODULE_LINKER_FLAGS ${MSVC_LINKER_FLAGS})
set(CMAKE_SHARED_LINKER_FLAGS ${MSVC_LINKER_FLAGS})
set(CMAKE_STATIC_LINKER_FLAGS ${MSVC_LINKER_FLAGS})
endif()
set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}" "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH}" "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
include(GtsamMakeConfigFile) include(GtsamMakeConfigFile)
include(GNUInstallDirs) include(GNUInstallDirs)

View File

@ -29,9 +29,9 @@
namespace gtsam { namespace gtsam {
/** /**
* Algebraic Decision Trees fix the range to double * An algebraic decision tree fixes the range of a DecisionTree to double.
* Just has some nice constructors and some syntactic sugar * Just has some nice constructors and some syntactic sugar.
* TODO: consider eliminating this class altogether? * TODO(dellaert): consider eliminating this class altogether?
* *
* @ingroup discrete * @ingroup discrete
*/ */
@ -81,20 +81,62 @@ namespace gtsam {
AlgebraicDecisionTree(const L& label, double y1, double y2) AlgebraicDecisionTree(const L& label, double y1, double y2)
: Base(label, y1, y2) {} : Base(label, y1, y2) {}
/** Create a new leaf function splitting on a variable */ /**
* @brief Create a new leaf function splitting on a variable
*
* @param labelC: The label with cardinality 2
* @param y1: The value for the first key
* @param y2: The value for the second key
*
* Example:
* @code{.cpp}
* std::pair<string, size_t> A {"a", 2};
* AlgebraicDecisionTree<string> a(A, 0.6, 0.4);
* @endcode
*/
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
double y2) double y2)
: Base(labelC, y1, y2) {} : Base(labelC, y1, y2) {}
/** Create from keys and vector table */ /**
* @brief Create from keys with cardinalities and a vector table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param ys: The vector table
*
* Example with three keys, A, B, and C, with cardinalities 2, 3, and 2,
* respectively, and a vector table of size 12:
* @code{.cpp}
* DiscreteKey A(0, 2), B(1, 3), C(2, 2);
* const vector<double> cpt{
* 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
* 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10};
* AlgebraicDecisionTree<Key> expected(A & B & C, cpt);
* @endcode
* The table is given in the following order:
* A=0, B=0, C=0
* A=0, B=0, C=1
* ...
* A=1, B=1, C=1
* Hence, the first line in the table is for A==0, and the second for A==1.
* In each line, the first two entries are for B==0, the next two for B==1,
* and the last two for B==2. Each pair is for a C value of 0 and 1.
*/
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, (const std::vector<typename Base::LabelC>& labelCs,
const std::vector<double>& ys) { const std::vector<double>& ys) {
this->root_ = this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create from keys and string table */ /**
* @brief Create from keys and string table
*
* @param labelCs: The keys, with cardinalities, given as pairs
* @param table: The string table, given as a string of doubles.
*
* @note Table needs to be in same order as the vector table in the other constructor.
*/
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, (const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) { const std::string& table) {
@ -109,7 +151,13 @@ namespace gtsam {
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create a new function splitting on a variable */ /**
* @brief Create a range of decision trees, splitting on a single variable.
*
* @param begin: Iterator to beginning of a range of decision trees
* @param end: Iterator to end of a range of decision trees
* @param label: The label to split on
*/
template <typename Iterator> template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
: Base(nullptr) { : Base(nullptr) {

View File

@ -93,7 +93,8 @@ namespace gtsam {
/// print /// print
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; std::cout << s << " Leaf [" << nrAssignments() << "] "
<< valueFormatter(constant_) << std::endl;
} }
/** Write graphviz format to stream `os`. */ /** Write graphviz format to stream `os`. */
@ -626,7 +627,7 @@ namespace gtsam {
// B=1 // B=1
// A=0: 3 // A=0: 3
// A=1: 4 // A=1: 4
// Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce // Note, through the magic of "compose", create([A B],[1 3 2 4]) will produce
// exactly the same tree as above: the highest label is always the root. // exactly the same tree as above: the highest label is always the root.
// However, it will be *way* faster if labels are given highest to lowest. // However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y> template<typename L, typename Y>
@ -827,6 +828,16 @@ namespace gtsam {
return total; return total;
} }
/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrAssignments() const {
size_t n = 0;
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
n += leaf.nrAssignments();
});
return n;
}
/****************************************************************************/ /****************************************************************************/
// fold is just done with a visit // fold is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>

View File

@ -39,9 +39,23 @@
namespace gtsam { namespace gtsam {
/** /**
* Decision Tree * @brief a decision tree is a function from assignments to values.
* L = label for variables * @tparam L label for variables
* Y = function range (any algebra), e.g., bool, int, double * @tparam Y function range (any algebra), e.g., bool, int, double
*
* After creating a decision tree on some variables, the tree can be evaluated
* on an assignment to those variables. Example:
*
* @code{.cpp}
* // Create a decision stump one one variable 'a' with values 10 and 20.
* DecisionTree<char, int> tree('a', 10, 20);
*
* // Evaluate the tree on an assignment to the variable.
* int value0 = tree({{'a', 0}}); // value0 = 10
* int value1 = tree({{'a', 1}}); // value1 = 20
* @endcode
*
* More examples can be found in testDecisionTree.cpp
* *
* @ingroup discrete * @ingroup discrete
*/ */
@ -136,7 +150,8 @@ namespace gtsam {
NodePtr root_; NodePtr root_;
protected: protected:
/** Internal recursive function to create from keys, cardinalities, /**
* Internal recursive function to create from keys, cardinalities,
* and Y values * and Y values
*/ */
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
@ -167,7 +182,13 @@ namespace gtsam {
/** Create a constant */ /** Create a constant */
explicit DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` /**
* @brief Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
*
* @param label The variable to split on.
* @param y1 The value for the first assignment.
* @param y2 The value for the second assignment.
*/
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
/** Allow Label+Cardinality for convenience */ /** Allow Label+Cardinality for convenience */
@ -299,6 +320,42 @@ namespace gtsam {
/// Return the number of leaves in the tree. /// Return the number of leaves in the tree.
size_t nrLeaves() const; size_t nrLeaves() const;
/**
* @brief This is a convenience function which returns the total number of
* leaf assignments in the decision tree.
* This function is not used for anymajor operations within the discrete
* factor graph framework.
*
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
* binary tree each leaf has 2 assignments. This includes counts removed
* from implicit pruning hence, it will always be >= nrLeaves().
*
* E.g. we have a decision tree as below, where each node has 2 branches:
*
* Choice(m1)
* 0 Choice(m0)
* 0 0 Leaf 0.0
* 0 1 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
* and 4 leaves.
*
* In the pruned form, the number of assignments is still 4 but the number
* of leaves is now 3, as below:
*
* Choice(m1)
* 0 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* @return size_t
*/
size_t nrAssignments() const;
/** /**
* @brief Fold a binary function over the tree, returning accumulator. * @brief Fold a binary function over the tree, returning accumulator.
* *

View File

@ -101,6 +101,14 @@ namespace gtsam {
return DecisionTreeFactor(keys, result); return DecisionTreeFactor(keys, result);
} }
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
// apply operand
ADT result = ADT::apply(op);
// Make a new factor
return DecisionTreeFactor(discreteKeys(), result);
}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const { size_t nrFrontals, ADT::Binary op) const {

View File

@ -59,11 +59,46 @@ namespace gtsam {
/** Constructor from DiscreteKeys and AlgebraicDecisionTree */ /** Constructor from DiscreteKeys and AlgebraicDecisionTree */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */ /**
* @brief Constructor from doubles
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* const std::vector<double> table {2, 5, 3, 6, 4, 7};
* DecisionTreeFactor f1({X, Y}, table);
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table); const std::vector<double>& table);
/** Constructor from string */ /**
* @brief Constructor from string
*
* @param keys The discrete keys.
* @param table The table of values.
*
* @throw std::invalid_argument if the size of `table` does not match the
* number of assignments.
*
* Example:
* @code{.cpp}
* DiscreteKey X(0,2), Y(1,3);
* DecisionTreeFactor factor({X, Y}, "2 5 3 6 4 7");
* @endcode
*
* The values in the table should be laid out so that the first key varies
* the slowest, and the last key the fastest.
*/
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
/// Single-key specialization /// Single-key specialization
@ -147,6 +182,12 @@ namespace gtsam {
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/**
* Apply unary operator (*this) "op" f
* @param op a unary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
/** /**
* Apply binary operator (*this) "op" f * Apply binary operator (*this) "op" f
* @param f the second argument for op * @param f the second argument for op

View File

@ -58,6 +58,11 @@ class GTSAM_EXPORT DiscreteBayesTreeClique
//** evaluate conditional probability of subtree for given DiscreteValues */ //** evaluate conditional probability of subtree for given DiscreteValues */
double evaluate(const DiscreteValues& values) const; double evaluate(const DiscreteValues& values) const;
//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}
}; };
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -42,15 +42,29 @@ class DiscreteJunctionTree;
/** /**
* @brief Main elimination function for DiscreteFactorGraph. * @brief Main elimination function for DiscreteFactorGraph.
* *
* @param factors * @param factors The factor graph to eliminate.
* @param keys * @param frontalKeys An ordering for which variables to eliminate.
* @return GTSAM_EXPORT * @return A pair of the resulting conditional and the separator factor.
* @ingroup discrete * @ingroup discrete
*/ */
GTSAM_EXPORT std::pair<std::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr> GTSAM_EXPORT
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys); std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/**
* @brief Alternate elimination function for that creates non-normalized lookup tables.
*
* @param factors The factor graph to eliminate.
* @param frontalKeys An ordering for which variables to eliminate.
* @return A pair of the resulting lookup table and the separator factor.
* @ingroup discrete
*/
GTSAM_EXPORT
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr>
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/* ************************************************************************* */
template<> struct EliminationTraits<DiscreteFactorGraph> template<> struct EliminationTraits<DiscreteFactorGraph>
{ {
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
@ -60,12 +74,14 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function /// The default dense elimination function
static std::pair<std::shared_ptr<ConditionalType>, static std::pair<std::shared_ptr<ConditionalType>,
std::shared_ptr<FactorType> > std::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys); return EliminateDiscrete(factors, keys);
} }
/// The default ordering generation function /// The default ordering generation function
static Ordering DefaultOrderingFunc( static Ordering DefaultOrderingFunc(
const FactorGraphType& graph, const FactorGraphType& graph,
@ -74,7 +90,6 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
} }
}; };
/* ************************************************************************* */
/** /**
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e. * A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor * Factor == DiscreteFactor
@ -108,8 +123,8 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Implicit copy/downcast constructor to override explicit template container /** Implicit copy/downcast constructor to override explicit template container
* constructor */ * constructor */
template <class DERIVEDFACTOR> template <class DERIVED_FACTOR>
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {} DiscreteFactorGraph(const FactorGraph<DERIVED_FACTOR>& graph) : Base(graph) {}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -227,10 +242,6 @@ class GTSAM_EXPORT DiscreteFactorGraph
/// @} /// @}
}; // \ DiscreteFactorGraph }; // \ DiscreteFactorGraph
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys);
/// traits /// traits
template <> template <>
struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {}; struct traits<DiscreteFactorGraph> : public Testable<DiscreteFactorGraph> {};

View File

@ -66,4 +66,6 @@ namespace gtsam {
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree); DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
}; };
/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
} }

View File

@ -120,6 +120,11 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
/// @} /// @}
}; };
/// Free version of CartesianProduct.
inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
return DiscreteValues::CartesianProduct(keys);
}
/// Free version of markdown. /// Free version of markdown.
std::string markdown(const DiscreteValues& values, std::string markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,

View File

@ -23,6 +23,8 @@
#include <vector> #include <vector>
#include <gtsam/dllexport.h> #include <gtsam/dllexport.h>
#include <gtsam/dllexport.h>
namespace gtsam { namespace gtsam {
/** /**
* @brief A simple parser that replaces the boost spirit parser. * @brief A simple parser that replaces the boost spirit parser.

View File

@ -64,7 +64,7 @@ TableFactor::TableFactor(const DiscreteConditional& c)
Eigen::SparseVector<double> TableFactor::Convert( Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) { const std::vector<double>& table) {
Eigen::SparseVector<double> sparse_table(table.size()); Eigen::SparseVector<double> sparse_table(table.size());
// Count number of nonzero elements in table and reserving the space. // Count number of nonzero elements in table and reserve the space.
const uint64_t nnz = std::count_if(table.begin(), table.end(), const uint64_t nnz = std::count_if(table.begin(), table.end(),
[](uint64_t i) { return i != 0; }); [](uint64_t i) { return i != 0; });
sparse_table.reserve(nnz); sparse_table.reserve(nnz);
@ -218,6 +218,45 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
} }
/* ************************************************************************ */
TableFactor TableFactor::apply(Unary op) const {
// Initialize new factor.
uint64_t cardi = 1;
for (auto [key, c] : cardinalities_) cardi *= c;
Eigen::SparseVector<double> sparse_table(cardi);
sparse_table.reserve(sparse_table_.nonZeros());
// Populate
for (SparseIt it(sparse_table_); it; ++it) {
sparse_table.coeffRef(it.index()) = op(it.value());
}
// Free unused memory and return.
sparse_table.pruned();
sparse_table.data().squeeze();
return TableFactor(discreteKeys(), sparse_table);
}
/* ************************************************************************ */
TableFactor TableFactor::apply(UnaryAssignment op) const {
// Initialize new factor.
uint64_t cardi = 1;
for (auto [key, c] : cardinalities_) cardi *= c;
Eigen::SparseVector<double> sparse_table(cardi);
sparse_table.reserve(sparse_table_.nonZeros());
// Populate
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
sparse_table.coeffRef(it.index()) = op(assignment, it.value());
}
// Free unused memory and return.
sparse_table.pruned();
sparse_table.data().squeeze();
return TableFactor(discreteKeys(), sparse_table);
}
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const { TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
if (keys_.empty() && sparse_table_.nonZeros() == 0) if (keys_.empty() && sparse_table_.nonZeros() == 0)

View File

@ -93,6 +93,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
typedef std::shared_ptr<TableFactor> shared_ptr; typedef std::shared_ptr<TableFactor> shared_ptr;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt; typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList; typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
using Unary = std::function<double(const double&)>;
using UnaryAssignment =
std::function<double(const Assignment<Key>&, const double&)>;
using Binary = std::function<double(const double, const double)>; using Binary = std::function<double(const double, const double)>;
public: public:
@ -218,6 +221,18 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/**
* Apply unary operator `op(*this)` where `op` accepts the discrete value.
* @param op a unary operator that operates on TableFactor
*/
TableFactor apply(Unary op) const;
/**
* Apply unary operator `op(*this)` where `op` accepts the discrete assignment
* and the value at that assignment.
* @param op a unary operator that operates on TableFactor
*/
TableFactor apply(UnaryAssignment op) const;
/** /**
* Apply binary operator (*this) "op" f * Apply binary operator (*this) "op" f
* @param f the second argument for op * @param f the second argument for op
@ -225,10 +240,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/ */
TableFactor apply(const TableFactor& f, Binary op) const; TableFactor apply(const TableFactor& f, Binary op) const;
/// Return keys in contract mode. /**
* Return keys in contract mode.
*
* Modes are each of the dimensions of a sparse tensor,
* and the contract modes represent which dimensions will
* be involved in contraction (aka tensor multiplication).
*/
DiscreteKeys contractDkeys(const TableFactor& f) const; DiscreteKeys contractDkeys(const TableFactor& f) const;
/// Return keys in free mode. /**
* @brief Return keys in free mode which are the dimensions
* not involved in the contraction operation.
*/
DiscreteKeys freeDkeys(const TableFactor& f) const; DiscreteKeys freeDkeys(const TableFactor& f) const;
/// Return union of DiscreteKeys in two factors. /// Return union of DiscreteKeys in two factors.

View File

@ -17,6 +17,8 @@ class DiscreteKeys {
}; };
// DiscreteValues is added in specializations/discrete.h as a std::map // DiscreteValues is added in specializations/discrete.h as a std::map
std::vector<gtsam::DiscreteValues> cartesianProduct(
const gtsam::DiscreteKeys& keys);
string markdown( string markdown(
const gtsam::DiscreteValues& values, const gtsam::DiscreteValues& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
@ -31,13 +33,11 @@ string html(const gtsam::DiscreteValues& values,
std::map<gtsam::Key, std::vector<std::string>> names); std::map<gtsam::Key, std::vector<std::string>> names);
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
class DiscreteFactor { virtual class DiscreteFactor : gtsam::Factor {
void print(string s = "DiscreteFactor\n", void print(string s = "DiscreteFactor\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
}; };
@ -49,7 +49,12 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
const std::vector<double>& spec); const std::vector<double>& spec);
DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); DecisionTreeFactor(const gtsam::DiscreteKey& key, string table);
DecisionTreeFactor(const gtsam::DiscreteKeys& keys,
const std::vector<double>& table);
DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table);
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys,
const std::vector<double>& table);
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table); DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteConditional& c); DecisionTreeFactor(const gtsam::DiscreteConditional& c);
@ -59,6 +64,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
size_t cardinality(gtsam::Key j) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const; size_t cardinality(gtsam::Key j) const;
@ -66,6 +73,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const;
gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const;
gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; gtsam::DecisionTreeFactor* max(size_t nrFrontals) const;
gtsam::DecisionTreeFactor* max(const gtsam::Ordering& keys) const;
string dot( string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
@ -203,10 +211,16 @@ class DiscreteBayesTreeClique {
DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional); DiscreteBayesTreeClique(const gtsam::DiscreteConditional* conditional);
const gtsam::DiscreteConditional* conditional() const; const gtsam::DiscreteConditional* conditional() const;
bool isRoot() const; bool isRoot() const;
size_t nrChildren() const;
const gtsam::DiscreteBayesTreeClique* operator[](size_t i) const;
void print(string s = "DiscreteBayesTreeClique",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printSignature( void printSignature(
const string& s = "Clique: ", const string& s = "Clique: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
double evaluate(const gtsam::DiscreteValues& values) const; double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
}; };
class DiscreteBayesTree { class DiscreteBayesTree {
@ -220,6 +234,9 @@ class DiscreteBayesTree {
bool empty() const; bool empty() const;
const DiscreteBayesTreeClique* operator[](size_t j) const; const DiscreteBayesTreeClique* operator[](size_t j) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
string dot(const gtsam::KeyFormatter& keyFormatter = string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, void saveGraph(string s,
@ -242,9 +259,9 @@ class DiscreteBayesTree {
class DiscreteLookupTable : gtsam::DiscreteConditional{ class DiscreteLookupTable : gtsam::DiscreteConditional{
DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys, DiscreteLookupTable(size_t nFrontals, const gtsam::DiscreteKeys& keys,
const gtsam::DecisionTreeFactor::ADT& potentials); const gtsam::DecisionTreeFactor::ADT& potentials);
void print( void print(string s = "Discrete Lookup Table: ",
const std::string& s = "Discrete Lookup Table: ", const gtsam::KeyFormatter& keyFormatter =
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
size_t argmax(const gtsam::DiscreteValues& parentsValues) const; size_t argmax(const gtsam::DiscreteValues& parentsValues) const;
}; };
@ -263,6 +280,14 @@ class DiscreteLookupDAG {
}; };
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateDiscrete(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
std::pair<gtsam::DiscreteConditional*, gtsam::DecisionTreeFactor*>
EliminateForMPE(const gtsam::DiscreteFactorGraph& factors,
const gtsam::Ordering& frontalKeys);
class DiscreteFactorGraph { class DiscreteFactorGraph {
DiscreteFactorGraph(); DiscreteFactorGraph();
DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet);
@ -277,6 +302,7 @@ class DiscreteFactorGraph {
void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec); void add(const gtsam::DiscreteKey& j, const std::vector<double>& spec);
void add(const gtsam::DiscreteKeys& keys, string spec); void add(const gtsam::DiscreteKeys& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, string spec); void add(const std::vector<gtsam::DiscreteKey>& keys, string spec);
void add(const std::vector<gtsam::DiscreteKey>& keys, const std::vector<double>& spec);
bool empty() const; bool empty() const;
size_t size() const; size_t size() const;
@ -290,25 +316,46 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteBayesNet sumProduct(); gtsam::DiscreteBayesNet sumProduct(
gtsam::DiscreteBayesNet sumProduct(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet sumProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteLookupDAG maxProduct(); gtsam::DiscreteLookupDAG maxProduct(
gtsam::DiscreteLookupDAG maxProduct(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential(); gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::DiscreteBayesNet* eliminateSequential(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesNet* eliminateSequential(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet* eliminateSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*> pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(const gtsam::Ordering& ordering); eliminatePartialSequential(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesNet*, gtsam::DiscreteFactorGraph*>
eliminatePartialSequential(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal(); gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::DiscreteBayesTree* eliminateMultifrontal(gtsam::Ordering::OrderingType type); gtsam::Ordering::OrderingType type = gtsam::Ordering::COLAMD);
gtsam::DiscreteBayesTree* eliminateMultifrontal(const gtsam::Ordering& ordering); gtsam::DiscreteBayesTree* eliminateMultifrontal(
gtsam::Ordering::OrderingType type,
const gtsam::DiscreteFactorGraph::Eliminate& function);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*> pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering); eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
pair<gtsam::DiscreteBayesTree*, gtsam::DiscreteFactorGraph*>
eliminatePartialMultifrontal(
const gtsam::Ordering& ordering,
const gtsam::DiscreteFactorGraph::Eliminate& function);
string dot( string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
@ -328,4 +375,41 @@ class DiscreteFactorGraph {
std::map<gtsam::Key, std::vector<std::string>> names) const; std::map<gtsam::Key, std::vector<std::string>> names) const;
}; };
#include <gtsam/discrete/DiscreteEliminationTree.h>
class DiscreteEliminationTree {
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
DiscreteEliminationTree(const gtsam::DiscreteFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/discrete/DiscreteJunctionTree.h>
class DiscreteCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::DiscreteFactorGraph factors;
const gtsam::DiscreteCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class DiscreteJunctionTree {
DiscreteJunctionTree(const gtsam::DiscreteEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::DiscreteCluster& operator[](size_t i) const;
};
} // namespace gtsam } // namespace gtsam

View File

@ -25,6 +25,7 @@
#include <gtsam/base/serializationTestHelpers.h> #include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Symbol.h>
#include <iomanip> #include <iomanip>
@ -75,6 +76,19 @@ struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ************************************************************************** */
// Test char labels and int range
/* ************************************************************************** */
// Create a decision stump one one variable 'a' with values 10 and 20.
TEST(DecisionTree, Constructor) {
DecisionTree<char, int> tree('a', 10, 20);
// Evaluate the tree on an assignment to the variable.
EXPECT_LONGS_EQUAL(10, tree({{'a', 0}}));
EXPECT_LONGS_EQUAL(20, tree({{'a', 1}}));
}
/* ************************************************************************** */ /* ************************************************************************** */
// Test string labels and int range // Test string labels and int range
/* ************************************************************************** */ /* ************************************************************************** */
@ -118,18 +132,47 @@ struct Ring {
static inline int mul(const int& a, const int& b) { return a * b; } static inline int mul(const int& a, const int& b) { return a * b; }
}; };
/* ************************************************************************** */
// Check that creating decision trees respects key order.
TEST(DecisionTree, ConstructorOrder) {
// Create labels
string A("A"), B("B");
const std::vector<int> ys1 = {1, 2, 3, 4};
DT tree1({{B, 2}, {A, 2}}, ys1); // faster version, as B is "higher" than A!
const std::vector<int> ys2 = {1, 3, 2, 4};
DT tree2({{A, 2}, {B, 2}}, ys2); // slower version !
// Both trees will be the same, tree is order from high to low labels.
// Choice(B)
// 0 Choice(A)
// 0 0 Leaf 1
// 0 1 Leaf 2
// 1 Choice(A)
// 1 0 Leaf 3
// 1 1 Leaf 4
EXPECT(tree2.equals(tree1));
// Check the values are as expected by calling the () operator:
EXPECT_LONGS_EQUAL(1, tree1({{A, 0}, {B, 0}}));
EXPECT_LONGS_EQUAL(3, tree1({{A, 0}, {B, 1}}));
EXPECT_LONGS_EQUAL(2, tree1({{A, 1}, {B, 0}}));
EXPECT_LONGS_EQUAL(4, tree1({{A, 1}, {B, 1}}));
}
/* ************************************************************************** */ /* ************************************************************************** */
// test DT // test DT
TEST(DecisionTree, example) { TEST(DecisionTree, Example) {
// Create labels // Create labels
string A("A"), B("B"), C("C"); string A("A"), B("B"), C("C");
// create a value // Create assignments using brace initialization:
Assignment<string> x00, x01, x10, x11; Assignment<string> x00{{A, 0}, {B, 0}};
x00[A] = 0, x00[B] = 0; Assignment<string> x01{{A, 0}, {B, 1}};
x01[A] = 0, x01[B] = 1; Assignment<string> x10{{A, 1}, {B, 0}};
x10[A] = 1, x10[B] = 0; Assignment<string> x11{{A, 1}, {B, 1}};
x11[A] = 1, x11[B] = 1;
// empty // empty
DT empty; DT empty;
@ -241,8 +284,7 @@ TEST(DecisionTree, ConvertValuesOnly) {
StringBoolTree f2(f1, bool_of_int); StringBoolTree f2(f1, bool_of_int);
// Check a value // Check a value
Assignment<string> x00; Assignment<string> x00 {{A, 0}, {B, 0}};
x00["A"] = 0, x00["B"] = 0;
EXPECT(!f2(x00)); EXPECT(!f2(x00));
} }
@ -266,10 +308,11 @@ TEST(DecisionTree, ConvertBoth) {
// Check some values // Check some values
Assignment<Label> x00, x01, x10, x11; Assignment<Label> x00, x01, x10, x11;
x00[X] = 0, x00[Y] = 0; x00 = {{X, 0}, {Y, 0}};
x01[X] = 0, x01[Y] = 1; x01 = {{X, 0}, {Y, 1}};
x10[X] = 1, x10[Y] = 0; x10 = {{X, 1}, {Y, 0}};
x11[X] = 1, x11[Y] = 1; x11 = {{X, 1}, {Y, 1}};
EXPECT(!f2(x00)); EXPECT(!f2(x00));
EXPECT(!f2(x01)); EXPECT(!f2(x01));
EXPECT(f2(x10)); EXPECT(f2(x10));

View File

@ -27,6 +27,18 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */
TEST(DecisionTreeFactor, ConstructorsMatch) {
// Declare two keys
DiscreteKey X(0, 2), Y(1, 3);
// Create with vector and with string
const std::vector<double> table {2, 5, 3, 6, 4, 7};
DecisionTreeFactor f1({X, Y}, table);
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
EXPECT(assert_equal(f1, f2));
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DecisionTreeFactor, constructors) TEST( DecisionTreeFactor, constructors)
{ {
@ -41,21 +53,18 @@ TEST( DecisionTreeFactor, constructors)
EXPECT_LONGS_EQUAL(2,f2.size()); EXPECT_LONGS_EQUAL(2,f2.size());
EXPECT_LONGS_EQUAL(3,f3.size()); EXPECT_LONGS_EQUAL(3,f3.size());
DiscreteValues values; DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
values[0] = 1; // x EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
values[1] = 2; // y EXPECT_DOUBLES_EQUAL(7, f2(x121), 1e-9);
values[2] = 1; // z EXPECT_DOUBLES_EQUAL(75, f3(x121), 1e-9);
EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
// Assert that error = -log(value) // Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); EXPECT_DOUBLES_EQUAL(-log(f1(x121)), f1.error(x121), 1e-9);
// Construct from DiscreteConditional // Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4"); DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
DecisionTreeFactor f4(conditional); DecisionTreeFactor f4(conditional);
EXPECT_DOUBLES_EQUAL(0.8, f4(values), 1e-9); EXPECT_DOUBLES_EQUAL(0.8, f4(x121), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -16,23 +16,24 @@
*/ */
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteBayesTree.h> #include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesNet.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
using namespace std;
using namespace gtsam; using namespace gtsam;
static constexpr bool debug = false; static constexpr bool debug = false;
/* ************************************************************************* */ /* ************************************************************************* */
struct TestFixture { struct TestFixture {
vector<DiscreteKey> keys; DiscreteKeys keys;
std::vector<DiscreteValues> assignments;
DiscreteBayesNet bayesNet; DiscreteBayesNet bayesNet;
std::shared_ptr<DiscreteBayesTree> bayesTree; std::shared_ptr<DiscreteBayesTree> bayesTree;
@ -47,6 +48,9 @@ struct TestFixture {
keys.push_back(key_i); keys.push_back(key_i);
} }
// Enumerate all assignments.
assignments = DiscreteValues::CartesianProduct(keys);
// Create thin-tree Bayesnet. // Create thin-tree Bayesnet.
bayesNet.add(keys[14] % "1/3"); bayesNet.add(keys[14] % "1/3");
@ -74,9 +78,9 @@ struct TestFixture {
}; };
/* ************************************************************************* */ /* ************************************************************************* */
// Check that BN and BT give the same answer on all configurations
TEST(DiscreteBayesTree, ThinTree) { TEST(DiscreteBayesTree, ThinTree) {
const TestFixture self; TestFixture self;
const auto& keys = self.keys;
if (debug) { if (debug) {
GTSAM_PRINT(self.bayesNet); GTSAM_PRINT(self.bayesNet);
@ -95,47 +99,56 @@ TEST(DiscreteBayesTree, ThinTree) {
EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals())); EXPECT_LONGS_EQUAL(i, *(clique_i->conditional_->beginFrontals()));
} }
auto R = self.bayesTree->roots().front(); for (const auto& x : self.assignments) {
// Check whether BN and BT give the same answer on all configurations
auto allPosbValues = DiscreteValues::CartesianProduct(
keys[0] & keys[1] & keys[2] & keys[3] & keys[4] & keys[5] & keys[6] &
keys[7] & keys[8] & keys[9] & keys[10] & keys[11] & keys[12] & keys[13] &
keys[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteValues x = allPosbValues[i];
double expected = self.bayesNet.evaluate(x); double expected = self.bayesNet.evaluate(x);
double actual = self.bayesTree->evaluate(x); double actual = self.bayesTree->evaluate(x);
DOUBLES_EQUAL(expected, actual, 1e-9); DOUBLES_EQUAL(expected, actual, 1e-9);
} }
}
// Calculate all some marginals for DiscreteValues==all1 /* ************************************************************************* */
Vector marginals = Vector::Zero(15); // Check calculation of separator marginals
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, TEST(DiscreteBayesTree, SeparatorMarginals) {
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, TestFixture self;
joint_4_11 = 0, joint_11_13 = 0, joint_11_13_14 = 0,
joint_11_12_13_14 = 0, joint_9_11_12_13 = 0, joint_8_11_12_13 = 0; // Calculate some marginals for DiscreteValues==all1
for (size_t i = 0; i < allPosbValues.size(); ++i) { double marginal_14 = 0, joint_8_12 = 0;
DiscreteValues x = allPosbValues[i]; for (auto& x : self.assignments) {
double px = self.bayesTree->evaluate(x); double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
joint_12_14 += px;
if (x[9]) joint_9_12_14 += px;
if (x[8]) joint_8_12_14 += px;
}
if (x[8] && x[12]) joint_8_12 += px; if (x[8] && x[12]) joint_8_12 += px;
if (x[2]) { if (x[14]) marginal_14 += px;
if (x[8]) joint82 += px; }
if (x[1]) joint12 += px; DiscreteValues all1 = self.assignments.back();
}
if (x[4]) { // check separator marginal P(S0)
if (x[2]) joint24 += px; auto clique = (*self.bayesTree)[0];
if (x[5]) joint45 += px; DiscreteFactorGraph separatorMarginal0 =
if (x[6]) joint46 += px; clique->separatorMarginal(EliminateDiscrete);
if (x[11]) joint_4_11 += px; DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
}
// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginal_14, separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
}
/* ************************************************************************* */
// Check shortcuts in the tree
TEST(DiscreteBayesTree, Shortcuts) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
double joint_11_13 = 0, joint_11_13_14 = 0, joint_11_12_13_14 = 0,
joint_9_11_12_13 = 0, joint_8_11_12_13 = 0;
for (auto& x : self.assignments) {
double px = self.bayesTree->evaluate(x);
if (x[11] && x[13]) { if (x[11] && x[13]) {
joint_11_13 += px; joint_11_13 += px;
if (x[8] && x[12]) joint_8_11_12_13 += px; if (x[8] && x[12]) joint_8_11_12_13 += px;
@ -148,32 +161,12 @@ TEST(DiscreteBayesTree, ThinTree) {
} }
} }
} }
DiscreteValues all1 = allPosbValues.back(); DiscreteValues all1 = self.assignments.back();
// check separator marginal P(S0) auto R = self.bayesTree->roots().front();
auto clique = (*self.bayesTree)[0];
DiscreteFactorGraph separatorMarginal0 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
clique = (*self.bayesTree)[11];
DiscreteFactorGraph separatorMarginal11 =
clique->separatorMarginal(EliminateDiscrete);
LONGS_EQUAL(0, separatorMarginal11.size());
// check shortcut P(S9||R) to root // check shortcut P(S9||R) to root
clique = (*self.bayesTree)[9]; auto clique = (*self.bayesTree)[9];
DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete); DiscreteBayesNet shortcut = clique->shortcut(R, EliminateDiscrete);
LONGS_EQUAL(1, shortcut.size()); LONGS_EQUAL(1, shortcut.size());
DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9); DOUBLES_EQUAL(joint_11_13_14 / joint_11_13, shortcut.evaluate(all1), 1e-9);
@ -202,15 +195,67 @@ TEST(DiscreteBayesTree, ThinTree) {
shortcut.print("shortcut:"); shortcut.print("shortcut:");
} }
} }
}
/* ************************************************************************* */
// Check all marginals
TEST(DiscreteBayesTree, MarginalFactors) {
TestFixture self;
Vector marginals = Vector::Zero(15);
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
}
// Check all marginals // Check all marginals
DiscreteFactor::shared_ptr marginalFactor; DiscreteValues all1 = self.assignments.back();
for (size_t i = 0; i < 15; i++) { for (size_t i = 0; i < 15; i++) {
marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete); auto marginalFactor = self.bayesTree->marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1); double actual = (*marginalFactor)(all1);
DOUBLES_EQUAL(marginals[i], actual, 1e-9); DOUBLES_EQUAL(marginals[i], actual, 1e-9);
} }
}
/* ************************************************************************* */
// Check a number of joint marginals.
TEST(DiscreteBayesTree, Joints) {
TestFixture self;
// Calculate some marginals for DiscreteValues==all1
Vector marginals = Vector::Zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
for (size_t i = 0; i < self.assignments.size(); ++i) {
DiscreteValues& x = self.assignments[i];
double px = self.bayesTree->evaluate(x);
for (size_t i = 0; i < 15; i++)
if (x[i]) marginals[i] += px;
if (x[12] && x[14]) {
joint_12_14 += px;
if (x[9]) joint_9_12_14 += px;
if (x[8]) joint_8_12_14 += px;
}
if (x[2]) {
if (x[8]) joint82 += px;
if (x[1]) joint12 += px;
}
if (x[4]) {
if (x[2]) joint24 += px;
if (x[5]) joint45 += px;
if (x[6]) joint46 += px;
if (x[11]) joint_4_11 += px;
}
}
// regression tests:
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
DiscreteValues all1 = self.assignments.back();
DiscreteBayesNet::shared_ptr actualJoint; DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8, 2) // Check joint P(8, 2)
@ -240,8 +285,8 @@ TEST(DiscreteBayesTree, ThinTree) {
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscreteBayesTree, Dot) { TEST(DiscreteBayesTree, Dot) {
const TestFixture self; TestFixture self;
string actual = self.bayesTree->dot(); std::string actual = self.bayesTree->dot();
EXPECT(actual == EXPECT(actual ==
"digraph G{\n" "digraph G{\n"
"0[label=\"13, 11, 6, 7\"];\n" "0[label=\"13, 11, 6, 7\"];\n"
@ -268,6 +313,62 @@ TEST(DiscreteBayesTree, Dot) {
"}"); "}");
} }
/* ************************************************************************* */
// Check that we can have a multi-frontal lookup table
TEST(DiscreteBayesTree, Lookup) {
using gtsam::symbol_shorthand::A;
using gtsam::symbol_shorthand::X;
// Make a small planning-like graph: 3 states, 2 actions
DiscreteFactorGraph graph;
const DiscreteKey x1{X(1), 3}, x2{X(2), 3}, x3{X(3), 3};
const DiscreteKey a1{A(1), 2}, a2{A(2), 2};
// Constraint on start and goal
graph.add(DiscreteKeys{x1}, std::vector<double>{1, 0, 0});
graph.add(DiscreteKeys{x3}, std::vector<double>{0, 0, 1});
// Should I stay or should I go?
// "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
const double r = 10;
std::vector<double> table{
r, 0, 0, 0, r, 0, // x1 = 0
0, r, 0, 0, 0, r, // x1 = 1
0, 0, r, 0, 0, r // x1 = 2
};
graph.add(DiscreteKeys{x1, a1, x2}, table);
graph.add(DiscreteKeys{x2, a2, x3}, table);
// eliminate for MPE (maximum probable explanation).
Ordering ordering{A(2), X(3), X(1), A(1), X(2)};
auto lookup = graph.eliminateMultifrontal(ordering, EliminateForMPE);
// Check that the lookup table is correct
EXPECT_LONGS_EQUAL(2, lookup->size());
auto lookup_x1_a1_x2 = (*lookup)[X(1)]->conditional();
EXPECT_LONGS_EQUAL(3, lookup_x1_a1_x2->frontals().size());
// check that sum is 1.0 (not 100, as we now normalize)
DiscreteValues empty;
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_x1_a1_x2->sum(3))(empty), 1e-9);
// And that only non-zero reward is for x1 a1 x2 == 0 1 1
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_x1_a1_x2)({{X(1),0},{A(1),1},{X(2),1}}), 1e-9);
auto lookup_a2_x3 = (*lookup)[X(3)]->conditional();
// check that the sum depends on x2 and is non-zero only for x2 \in {1,2}
auto sum_x2 = lookup_a2_x3->sum(2);
EXPECT_DOUBLES_EQUAL(0, (*sum_x2)({{X(2),0}}), 1e-9);
EXPECT_DOUBLES_EQUAL(1.0, (*sum_x2)({{X(2),1}}), 1e-9);
EXPECT_DOUBLES_EQUAL(2.0, (*sum_x2)({{X(2),2}}), 1e-9);
EXPECT_LONGS_EQUAL(2, lookup_a2_x3->frontals().size());
// And that the non-zero rewards are for
// x2 a2 x3 == 1 1 2
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),1},{A(2),1},{X(3),2}}), 1e-9);
// x2 a2 x3 == 2 0 2
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),0},{X(3),2}}), 1e-9);
// x2 a2 x3 == 2 1 2
EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -93,8 +93,7 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
for (auto&& kv : measured_time) { for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first cout << "dropout: " << kv.first
<< " | TableFactor time: " << kv.second.first.count() << " | TableFactor time: " << kv.second.first.count()
<< " | DecisionTreeFactor time: " << kv.second.second.count() << << " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
endl;
} }
} }
@ -361,6 +360,39 @@ TEST(TableFactor, htmlWithValueFormatter) {
EXPECT(actual == expected); EXPECT(actual == expected);
} }
/* ************************************************************************* */
TEST(TableFactor, Unary) {
// Declare a bunch of keys
DiscreteKey X(0, 2), Y(1, 3);
// Create factors
TableFactor f(X & Y, "2 5 3 6 2 7");
auto op = [](const double x) { return 2 * x; };
auto g = f.apply(op);
TableFactor expected(X & Y, "4 10 6 12 4 14");
EXPECT(assert_equal(g, expected));
auto sq_op = [](const double x) { return x * x; };
auto g_sq = f.apply(sq_op);
TableFactor expected_sq(X & Y, "4 25 9 36 4 49");
EXPECT(assert_equal(g_sq, expected_sq));
}
/* ************************************************************************* */
TEST(TableFactor, UnaryAssignment) {
// Declare a bunch of keys
DiscreteKey X(0, 2), Y(1, 3);
// Create factors
TableFactor f(X & Y, "2 5 3 6 2 7");
auto op = [](const Assignment<Key>& key, const double x) { return 2 * x; };
auto g = f.apply(op);
TableFactor expected(X & Y, "4 10 6 12 4 14");
EXPECT(assert_equal(g, expected));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -146,7 +146,7 @@ class GTSAM_EXPORT Line3 {
* @param Dline - OptionalJacobian of transformed line with respect to l * @param Dline - OptionalJacobian of transformed line with respect to l
* @return Transformed line in camera frame * @return Transformed line in camera frame
*/ */
friend Line3 transformTo(const Pose3 &wTc, const Line3 &wL, GTSAM_EXPORT friend Line3 transformTo(const Pose3 &wTc, const Line3 &wL,
OptionalJacobian<4, 6> Dpose, OptionalJacobian<4, 6> Dpose,
OptionalJacobian<4, 4> Dline); OptionalJacobian<4, 4> Dline);
}; };

View File

@ -597,6 +597,25 @@ TEST(Rot3, quaternion) {
EXPECT(assert_equal(expected2, actual2)); EXPECT(assert_equal(expected2, actual2));
} }
/* ************************************************************************* */
TEST(Rot3, ConvertQuaternion) {
Eigen::Quaterniond eigenQuaternion;
eigenQuaternion.w() = 1.0;
eigenQuaternion.x() = 2.0;
eigenQuaternion.y() = 3.0;
eigenQuaternion.z() = 4.0;
EXPECT_DOUBLES_EQUAL(1, eigenQuaternion.w(), 1e-9);
EXPECT_DOUBLES_EQUAL(2, eigenQuaternion.x(), 1e-9);
EXPECT_DOUBLES_EQUAL(3, eigenQuaternion.y(), 1e-9);
EXPECT_DOUBLES_EQUAL(4, eigenQuaternion.z(), 1e-9);
Rot3 R(eigenQuaternion);
EXPECT_DOUBLES_EQUAL(1, R.toQuaternion().w(), 1e-9);
EXPECT_DOUBLES_EQUAL(2, R.toQuaternion().x(), 1e-9);
EXPECT_DOUBLES_EQUAL(3, R.toQuaternion().y(), 1e-9);
EXPECT_DOUBLES_EQUAL(4, R.toQuaternion().z(), 1e-9);
}
/* ************************************************************************* */ /* ************************************************************************* */
Matrix Cayley(const Matrix& A) { Matrix Cayley(const Matrix& A) {
Matrix::Index n = A.cols(); Matrix::Index n = A.cols();

View File

@ -37,24 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol); return Base::equals(bn, tol);
} }
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> discreteProbs;
// The canonical decision tree factor which will get
// the discrete conditionals added to it.
DecisionTreeFactor discreteProbsFactor;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscrete());
discreteProbsFactor = discreteProbsFactor * f;
}
}
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
}
/* ************************************************************************* */ /* ************************************************************************* */
/** /**
* @brief Helper function to get the pruner functional. * @brief Helper function to get the pruner functional.
@ -144,53 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
} }
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals( DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
const DecisionTreeFactor &prunedDiscreteProbs) { size_t maxNrLeaves) {
KeyVector prunedTreeKeys = prunedDiscreteProbs.keys(); // Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
// The joint discrete probability.
DiscreteConditional discreteProbs;
std::vector<size_t> discrete_factor_idxs;
// Record frontal keys so we can maintain ordering
Ordering discrete_frontals;
// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i); auto conditional = this->at(i);
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
auto discrete = conditional->asDiscrete(); discreteProbs = discreteProbs * (*conditional->asDiscrete());
// Convert pointer from conditional to factor Ordering conditional_keys(conditional->frontals());
auto discreteTree = discrete_frontals += conditional_keys;
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete); discrete_factor_idxs.push_back(i);
// Apply prunerFunc to the underlying AlgebraicDecisionTree
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
gttic_(HybridBayesNet_MakeConditional);
// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
gttoc_(HybridBayesNet_MakeConditional);
// Add it back to the BayesNet
this->at(i) = conditional;
} }
} }
const DecisionTreeFactor prunedDiscreteProbs =
discreteProbs.prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
// Eliminate joint probability back into conditionals
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
// Assign pruned discrete conditionals back at the correct indices.
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
return prunedDiscreteProbs;
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys DecisionTreeFactor prunedDiscreteProbs =
gttic_(HybridBayesNet_PruneDiscreteConditionals); this->pruneDiscreteConditionals(maxNrLeaves);
DecisionTreeFactor::shared_ptr discreteConditionals =
this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
gttic_(HybridBayesNet_UpdateDiscreteConditionals); /* To prune, we visitWith every leaf in the GaussianMixture.
this->updateDiscreteConditionals(prunedDiscreteProbs);
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree * For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr. * for 0.0 probability, then just set the leaf to a nullptr.
* *

View File

@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/ */
VectorValues optimize(const DiscreteValues &assignment) const; VectorValues optimize(const DiscreteValues &assignment) const;
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;
/** /**
* @brief Sample from an incomplete BayesNet, given missing variables. * @brief Sample from an incomplete BayesNet, given missing variables.
* *
@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
private: private:
/** /**
* @brief Update the discrete conditionals with the pruned versions. * @brief Prune all the discrete conditionals.
* *
* @param prunedDiscreteProbs * @param maxNrLeaves
*/ */
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs); DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */ /** Serialization function */

View File

@ -20,6 +20,7 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/Factor.h> #include <gtsam/inference/Factor.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/Values.h> #include <gtsam/nonlinear/Values.h>

View File

@ -17,7 +17,6 @@
* @date January, 2023 * @date January, 2023
*/ */
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
namespace gtsam { namespace gtsam {
@ -26,7 +25,7 @@ namespace gtsam {
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const { std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
std::set<DiscreteKey> keys; std::set<DiscreteKey> keys;
for (auto& factor : factors_) { for (auto& factor : factors_) {
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) { if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) { for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key); keys.insert(key);
} }
@ -67,6 +66,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
for (const Key& key : p->continuousKeys()) { for (const Key& key : p->continuousKeys()) {
keys.insert(key); keys.insert(key);
} }
} else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
keys.insert(p->keys().begin(), p->keys().end());
} }
} }
return keys; return keys;

View File

@ -48,8 +48,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
// #define HYBRID_TIMING
namespace gtsam { namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
// TODO(dellaert): in C++20, we can use std::visit. // TODO(dellaert): in C++20, we can use std::visit.
continue; continue;
} }
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// Don't do anything for discrete-only factors // Don't do anything for discrete-only factors
// since we want to eliminate continuous values only. // since we want to eliminate continuous values only.
continue; continue;
@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto &f : factors) { for (auto &f : factors) {
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) { if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(dtf); dfg.push_back(df);
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) { } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique. // Ignore orphaned clique.
// TODO(dellaert): is this correct? If so explain here. // TODO(dellaert): is this correct? If so explain here.
@ -262,6 +260,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
}; };
DecisionTree<Key, double> probabilities(eliminationResults, probability); DecisionTree<Key, double> probabilities(eliminationResults, probability);
return { return {
std::make_shared<HybridConditional>(gaussianMixture), std::make_shared<HybridConditional>(gaussianMixture),
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)}; std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
@ -348,64 +347,68 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// When the number of assignments is large we may encounter stack overflows. // When the number of assignments is large we may encounter stack overflows.
// However this is also the case with iSAM2, so no pressure :) // However this is also the case with iSAM2, so no pressure :)
// PREPROCESS: Identify the nature of the current elimination // Check the factors:
// TODO(dellaert): just check the factors:
// 1. if all factors are discrete, then we can do discrete elimination: // 1. if all factors are discrete, then we can do discrete elimination:
// 2. if all factors are continuous, then we can do continuous elimination: // 2. if all factors are continuous, then we can do continuous elimination:
// 3. if not, we do hybrid elimination: // 3. if not, we do hybrid elimination:
// First, identify the separator keys, i.e. all keys that are not frontal. bool only_discrete = true, only_continuous = true;
KeySet separatorKeys;
for (auto &&factor : factors) { for (auto &&factor : factors) {
separatorKeys.insert(factor->begin(), factor->end()); if (auto hybrid_factor = std::dynamic_pointer_cast<HybridFactor>(factor)) {
} if (hybrid_factor->isDiscrete()) {
// remove frontals from separator only_continuous = false;
for (auto &k : frontalKeys) { } else if (hybrid_factor->isContinuous()) {
separatorKeys.erase(k); only_discrete = false;
} } else if (hybrid_factor->isHybrid()) {
only_continuous = false;
// Build a map from keys to DiscreteKeys only_discrete = false;
auto mapFromKeyToDiscreteKey = factors.discreteKeyMap(); }
} else if (auto cont_factor =
// Fill in discrete frontals and continuous frontals. std::dynamic_pointer_cast<GaussianFactor>(factor)) {
std::set<DiscreteKey> discreteFrontals; only_discrete = false;
KeySet continuousFrontals; } else if (auto discrete_factor =
for (auto &k : frontalKeys) { std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { only_continuous = false;
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousFrontals.insert(k);
} }
} }
// Fill in discrete discrete separator keys and continuous separator keys.
std::set<DiscreteKey> discreteSeparatorSet;
KeyVector continuousSeparator;
for (auto &k : separatorKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousSeparator.push_back(k);
}
}
// Check if we have any continuous keys:
const bool discrete_only =
continuousFrontals.empty() && continuousSeparator.empty();
// NOTE: We should really defer the product here because of pruning // NOTE: We should really defer the product here because of pruning
if (discrete_only) { if (only_discrete) {
// Case 1: we are only dealing with discrete // Case 1: we are only dealing with discrete
return discreteElimination(factors, frontalKeys); return discreteElimination(factors, frontalKeys);
} else if (mapFromKeyToDiscreteKey.empty()) { } else if (only_continuous) {
// Case 2: we are only dealing with continuous // Case 2: we are only dealing with continuous
return continuousElimination(factors, frontalKeys); return continuousElimination(factors, frontalKeys);
} else { } else {
// Case 3: We are now in the hybrid land! // Case 3: We are now in the hybrid land!
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
// Find all the keys in the set of continuous keys
// which are not in the frontal keys. This is our continuous separator.
KeyVector continuousSeparator;
auto continuousKeySet = factors.continuousKeySet();
std::set_difference(
continuousKeySet.begin(), continuousKeySet.end(),
frontalKeysSet.begin(), frontalKeysSet.end(),
std::inserter(continuousSeparator, continuousSeparator.begin()));
// Similarly for the discrete separator.
KeySet discreteSeparatorSet;
std::set<DiscreteKey> discreteSeparator;
auto discreteKeySet = factors.discreteKeySet();
std::set_difference(
discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(),
frontalKeysSet.end(),
std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin()));
// Convert from set of keys to set of DiscreteKeys
auto discreteKeyMap = factors.discreteKeyMap();
for (auto key : discreteSeparatorSet) {
discreteSeparator.insert(discreteKeyMap.at(key));
}
return hybridElimination(factors, frontalKeys, continuousSeparator, return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparatorSet); discreteSeparator);
} }
} }
@ -429,7 +432,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// Add the gaussian factor error to every leaf of the error tree. // Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply( error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; }); [error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip. // If factor at `idx` is discrete-only, we skip.
continue; continue;
} else { } else {

View File

@ -40,6 +40,7 @@ class HybridEliminationTree;
class HybridBayesTree; class HybridBayesTree;
class HybridJunctionTree; class HybridJunctionTree;
class DecisionTreeFactor; class DecisionTreeFactor;
class TableFactor;
class JacobianFactor; class JacobianFactor;
class HybridValues; class HybridValues;

View File

@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
for (auto& k : hf->discreteKeys()) { for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first); data.discreteKeys.insert(k.first);
} }
} else if (auto hf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
for (auto& k : hf->discreteKeys()) { for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first); data.discreteKeys.insert(k.first);
} }
@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree(
Data rootData(0); Data rootData(0);
rootData.junctionTreeNode = rootData.junctionTreeNode =
std::make_shared<typename Base::Node>(); // Make a dummy node to gather std::make_shared<typename Base::Node>(); // Make a dummy node to gather
// the junction tree roots // the junction tree roots
treeTraversal::DepthFirstForest(eliminationTree, rootData, treeTraversal::DepthFirstForest(eliminationTree, rootData,
Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPre,
Data::ConstructorTraversalVisitorPost); Data::ConstructorTraversalVisitorPost);

View File

@ -17,6 +17,7 @@
*/ */
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h> #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h> #include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) { } else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues); const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
linearFG->push_back(gf); linearFG->push_back(gf);
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If discrete-only: doesn't need linearization. // If discrete-only: doesn't need linearization.
linearFG->push_back(f); linearFG->push_back(f);
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) { } else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {

View File

@ -72,7 +72,8 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
addConditionals(graph, hybridBayesNet_, ordering); addConditionals(graph, hybridBayesNet_, ordering);
// Eliminate. // Eliminate.
auto bayesNetFragment = graph.eliminateSequential(ordering); HybridBayesNet::shared_ptr bayesNetFragment =
graph.eliminateSequential(ordering);
/// Prune /// Prune
if (maxNrLeaves) { if (maxNrLeaves) {
@ -96,7 +97,8 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
HybridGaussianFactorGraph graph(originalGraph); HybridGaussianFactorGraph graph(originalGraph);
HybridBayesNet hybridBayesNet(originalHybridBayesNet); HybridBayesNet hybridBayesNet(originalHybridBayesNet);
// If we are not at the first iteration, means we have conditionals to add. // If hybridBayesNet is not empty,
// it means we have conditionals to add to the factor graph.
if (!hybridBayesNet.empty()) { if (!hybridBayesNet.empty()) {
// We add all relevant conditional mixtures on the last continuous variable // We add all relevant conditional mixtures on the last continuous variable
// in the previous `hybridBayesNet` to the graph // in the previous `hybridBayesNet` to the graph

View File

@ -35,14 +35,11 @@ class HybridValues {
}; };
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
virtual class HybridFactor { virtual class HybridFactor : gtsam::Factor {
void print(string s = "HybridFactor\n", void print(string s = "HybridFactor\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const; bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
gtsam::KeyVector keys() const;
// Standard interface: // Standard interface:
double error(const gtsam::HybridValues &values) const; double error(const gtsam::HybridValues &values) const;
@ -179,6 +176,7 @@ class HybridGaussianFactorGraph {
void push_back(const gtsam::HybridBayesTree& bayesTree); void push_back(const gtsam::HybridBayesTree& bayesTree);
void push_back(const gtsam::GaussianMixtureFactor* gmm); void push_back(const gtsam::GaussianMixtureFactor* gmm);
void push_back(gtsam::DecisionTreeFactor* factor); void push_back(gtsam::DecisionTreeFactor* factor);
void push_back(gtsam::TableFactor* factor);
void push_back(gtsam::JacobianFactor* factor); void push_back(gtsam::JacobianFactor* factor);
bool empty() const; bool empty() const;

View File

@ -202,31 +202,16 @@ struct Switching {
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2). * @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
* E.g. if K=4, we want M0, M1 and M2. * E.g. if K=4, we want M0, M1 and M2.
* *
* @param fg The nonlinear factor graph to which the mode chain is added. * @param fg The factor graph to which the mode chain is added.
*/ */
void addModeChain(HybridNonlinearFactorGraph *fg, template <typename FACTORGRAPH>
void addModeChain(FACTORGRAPH *fg,
std::string discrete_transition_prob = "1/2 3/2") { std::string discrete_transition_prob = "1/2 3/2") {
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1"); fg->template emplace_shared<DiscreteDistribution>(modes[0], "1/1");
for (size_t k = 0; k < K - 2; k++) { for (size_t k = 0; k < K - 2; k++) {
auto parents = {modes[k]}; auto parents = {modes[k]};
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents, fg->template emplace_shared<DiscreteConditional>(
discrete_transition_prob); modes[k + 1], parents, discrete_transition_prob);
}
}
/**
* @brief Add "mode chain" to HybridGaussianFactorGraph from M(0) to M(K-2).
* E.g. if K=4, we want M0, M1 and M2.
*
* @param fg The gaussian factor graph to which the mode chain is added.
*/
void addModeChain(HybridGaussianFactorGraph *fg,
std::string discrete_transition_prob = "1/2 3/2") {
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
for (size_t k = 0; k < K - 2; k++) {
auto parents = {modes[k]};
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
discrete_transition_prob);
} }
} }
}; };

View File

@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) {
std::string expected = std::string expected =
R"(Hybrid [x1 x2; 1]{ R"(Hybrid [x1 x2; 1]{
Choice(1) Choice(1)
0 Leaf : 0 Leaf [1] :
A[x1] = [ A[x1] = [
0; 0;
0 0
@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) {
b = [ 0 0 ] b = [ 0 0 ]
No noise model No noise model
1 Leaf : 1 Leaf [1] :
A[x1] = [ A[x1] = [
0; 0;
0 0

View File

@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
auto prunedTree = prunedBayesNet.evaluate(delta.continuous()); auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
// Regression test on pruned logProbability tree // Regression test on pruned logProbability tree
std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098}; std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves); AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
logProbability += logProbability +=
posterior->at(4)->asDiscrete()->logProbability(hybridValues); posterior->at(4)->asDiscrete()->logProbability(hybridValues);
// Regression
double density = exp(logProbability); double density = exp(logProbability);
EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(density,
1.6078460548731697 * actualTree(discrete_values), 1e-6);
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues), EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
1e-9); 1e-9);
@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
EXPECT_LONGS_EQUAL(7, posterior->size()); EXPECT_LONGS_EQUAL(7, posterior->size());
size_t maxNrLeaves = 3; size_t maxNrLeaves = 3;
auto discreteConditionals = posterior->discreteConditionals(); DiscreteConditional discreteConditionals;
for (auto&& conditional : *posterior) {
if (conditional->isDiscrete()) {
discreteConditionals =
discreteConditionals * (*conditional->asDiscrete());
}
}
const DecisionTreeFactor::shared_ptr prunedDecisionTree = const DecisionTreeFactor::shared_ptr prunedDecisionTree =
std::make_shared<DecisionTreeFactor>( std::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves)); discreteConditionals.prune(maxNrLeaves));
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
prunedDecisionTree->nrLeaves()); prunedDecisionTree->nrLeaves());
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete()); // regression
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
DecisionTreeFactor::ADT potentials(
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
// Prune! // Prune!
posterior->prune(maxNrLeaves); posterior->prune(maxNrLeaves);
// Functor to verify values against the original_discrete_conditionals // Functor to verify values against the expected_discrete_conditionals
auto checker = [&](const Assignment<Key>& assignment, auto checker = [&](const Assignment<Key>& assignment,
double probability) -> double { double probability) -> double {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
if (prunedDecisionTree->operator()(choices) == 0) { if (prunedDecisionTree->operator()(choices) == 0) {
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9); EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
} else { } else {
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability, EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
1e-9); 1e-9);
} }
return 0.0; return 0.0;

View File

@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) { for (auto&& f : *remainingFactorGraph) {
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f); auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
assert(discreteFactor); assert(discreteFactor);
dfg.push_back(discreteFactor); dfg.push_back(discreteFactor);
} }

View File

@ -140,6 +140,61 @@ TEST(HybridEstimation, IncrementalSmoother) {
EXPECT(assert_equal(expected_continuous, result)); EXPECT(assert_equal(expected_continuous, result));
} }
/****************************************************************************/
// Test approximate inference with an additional pruning step.
TEST(HybridEstimation, ISAM) {
size_t K = 15;
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
// Ground truth discrete seq
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
// Switching example of robot moving in 1D
// with given measurements and equal mode priors.
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
HybridNonlinearISAM isam;
HybridNonlinearFactorGraph graph;
Values initial;
// gttic_(Estimation);
// Add the X(0) prior
graph.push_back(switching.nonlinearFactorGraph.at(0));
initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
HybridGaussianFactorGraph linearized;
for (size_t k = 1; k < K; k++) {
// Motion Model
graph.push_back(switching.nonlinearFactorGraph.at(k));
// Measurement
graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
isam.update(graph, initial, 3);
// isam.bayesTree().print("\n\n");
graph.resize(0);
initial.clear();
}
Values result = isam.estimate();
DiscreteValues assignment = isam.assignment();
DiscreteValues expected_discrete;
for (size_t k = 0; k < K - 1; k++) {
expected_discrete[M(k)] = discrete_seq[k];
}
EXPECT(assert_equal(expected_discrete, assignment));
Values expected_continuous;
for (size_t k = 0; k < K; k++) {
expected_continuous.insert(X(k), measurements[k]);
}
EXPECT(assert_equal(expected_continuous, result));
}
/** /**
* @brief A function to get a specific 1D robot motion problem as a linearized * @brief A function to get a specific 1D robot motion problem as a linearized
* factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous * factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous

View File

@ -18,7 +18,9 @@
#include <gtsam/base/TestableAssertions.h> #include <gtsam/base/TestableAssertions.h>
#include <gtsam/base/utilities.h> #include <gtsam/base/utilities.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Symbol.h> #include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h>
#include <gtsam/nonlinear/PriorFactor.h> #include <gtsam/nonlinear/PriorFactor.h>
using namespace std; using namespace std;
@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) {
HybridFactorGraph fg; HybridFactorGraph fg;
} }
/* ************************************************************************* */
// Test if methods to get keys work as expected.
TEST(HybridFactorGraph, Keys) {
HybridGaussianFactorGraph hfg;
// Add prior on x0
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
// Add factor between x0 and x1
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
// Add a gaussian mixture factor ϕ(x1, c1)
DiscreteKey m1(M(1), 2);
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
M(1), std::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
std::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
KeySet expected_continuous{X(0), X(1)};
EXPECT(
assert_container_equality(expected_continuous, hfg.continuousKeySet()));
KeySet expected_discrete{M(1)};
EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet()));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -903,7 +903,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
// Test resulting posterior Bayes net has correct size: // Test resulting posterior Bayes net has correct size:
EXPECT_LONGS_EQUAL(8, posterior->size()); EXPECT_LONGS_EQUAL(8, posterior->size());
// TODO(dellaert): this test fails - no idea why. // Ratio test
EXPECT(ratioTest(bn, measurements, *posterior)); EXPECT(ratioTest(bn, measurements, *posterior));
} }

View File

@ -492,7 +492,7 @@ factor 0:
factor 1: factor 1:
Hybrid [x0 x1; m0]{ Hybrid [x0 x1; m0]{
Choice(m0) Choice(m0)
0 Leaf : 0 Leaf [1] :
A[x0] = [ A[x0] = [
-1 -1
] ]
@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{
b = [ -1 ] b = [ -1 ]
No noise model No noise model
1 Leaf : 1 Leaf [1] :
A[x0] = [ A[x0] = [
-1 -1
] ]
@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{
factor 2: factor 2:
Hybrid [x1 x2; m1]{ Hybrid [x1 x2; m1]{
Choice(m1) Choice(m1)
0 Leaf : 0 Leaf [1] :
A[x1] = [ A[x1] = [
-1 -1
] ]
@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{
b = [ -1 ] b = [ -1 ]
No noise model No noise model
1 Leaf : 1 Leaf [1] :
A[x1] = [ A[x1] = [
-1 -1
] ]
@ -550,16 +550,16 @@ factor 4:
b = [ -10 ] b = [ -10 ]
No noise model No noise model
factor 5: P( m0 ): factor 5: P( m0 ):
Leaf 0.5 Leaf [2] 0.5
factor 6: P( m1 | m0 ): factor 6: P( m1 | m0 ):
Choice(m1) Choice(m1)
0 Choice(m0) 0 Choice(m0)
0 0 Leaf 0.33333333 0 0 Leaf [1] 0.33333333
0 1 Leaf 0.6 0 1 Leaf [1] 0.6
1 Choice(m0) 1 Choice(m0)
1 0 Leaf 0.66666667 1 0 Leaf [1] 0.66666667
1 1 Leaf 0.4 1 1 Leaf [1] 0.4
)"; )";
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph)); EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
@ -570,13 +570,13 @@ size: 3
conditional 0: Hybrid P( x0 | x1 m0) conditional 0: Hybrid P( x0 | x1 m0)
Discrete Keys = (m0, 2), Discrete Keys = (m0, 2),
Choice(m0) Choice(m0)
0 Leaf p(x0 | x1) 0 Leaf [1] p(x0 | x1)
R = [ 10.0499 ] R = [ 10.0499 ]
S[x1] = [ -0.0995037 ] S[x1] = [ -0.0995037 ]
d = [ -9.85087 ] d = [ -9.85087 ]
No noise model No noise model
1 Leaf p(x0 | x1) 1 Leaf [1] p(x0 | x1)
R = [ 10.0499 ] R = [ 10.0499 ]
S[x1] = [ -0.0995037 ] S[x1] = [ -0.0995037 ]
d = [ -9.95037 ] d = [ -9.95037 ]
@ -586,26 +586,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1)
Discrete Keys = (m0, 2), (m1, 2), Discrete Keys = (m0, 2), (m1, 2),
Choice(m1) Choice(m1)
0 Choice(m0) 0 Choice(m0)
0 0 Leaf p(x1 | x2) 0 0 Leaf [1] p(x1 | x2)
R = [ 10.099 ] R = [ 10.099 ]
S[x2] = [ -0.0990196 ] S[x2] = [ -0.0990196 ]
d = [ -9.99901 ] d = [ -9.99901 ]
No noise model No noise model
0 1 Leaf p(x1 | x2) 0 1 Leaf [1] p(x1 | x2)
R = [ 10.099 ] R = [ 10.099 ]
S[x2] = [ -0.0990196 ] S[x2] = [ -0.0990196 ]
d = [ -9.90098 ] d = [ -9.90098 ]
No noise model No noise model
1 Choice(m0) 1 Choice(m0)
1 0 Leaf p(x1 | x2) 1 0 Leaf [1] p(x1 | x2)
R = [ 10.099 ] R = [ 10.099 ]
S[x2] = [ -0.0990196 ] S[x2] = [ -0.0990196 ]
d = [ -10.098 ] d = [ -10.098 ]
No noise model No noise model
1 1 Leaf p(x1 | x2) 1 1 Leaf [1] p(x1 | x2)
R = [ 10.099 ] R = [ 10.099 ]
S[x2] = [ -0.0990196 ] S[x2] = [ -0.0990196 ]
d = [ -10 ] d = [ -10 ]
@ -615,14 +615,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
Discrete Keys = (m0, 2), (m1, 2), Discrete Keys = (m0, 2), (m1, 2),
Choice(m1) Choice(m1)
0 Choice(m0) 0 Choice(m0)
0 0 Leaf p(x2) 0 0 Leaf [1] p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.1489 ] d = [ -10.1489 ]
mean: 1 elements mean: 1 elements
x2: -1.0099 x2: -1.0099
No noise model No noise model
0 1 Leaf p(x2) 0 1 Leaf [1] p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.1479 ] d = [ -10.1479 ]
mean: 1 elements mean: 1 elements
@ -630,14 +630,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
No noise model No noise model
1 Choice(m0) 1 Choice(m0)
1 0 Leaf p(x2) 1 0 Leaf [1] p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.0504 ] d = [ -10.0504 ]
mean: 1 elements mean: 1 elements
x2: -1.0001 x2: -1.0001
No noise model No noise model
1 1 Leaf p(x2) 1 1 Leaf [1] p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.0494 ] d = [ -10.0494 ]
mean: 1 elements mean: 1 elements

View File

@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) {
R"(Hybrid [x1 x2; 1] R"(Hybrid [x1 x2; 1]
MixtureFactor MixtureFactor
Choice(1) Choice(1)
0 Leaf Nonlinear factor on 2 keys 0 Leaf [1] Nonlinear factor on 2 keys
1 Leaf Nonlinear factor on 2 keys 1 Leaf [1] Nonlinear factor on 2 keys
)"; )";
EXPECT(assert_print_equal(expected, mixtureFactor)); EXPECT(assert_print_equal(expected, mixtureFactor));
} }

View File

@ -140,9 +140,15 @@ namespace gtsam {
/** Access the conditional */ /** Access the conditional */
const sharedConditional& conditional() const { return conditional_; } const sharedConditional& conditional() const { return conditional_; }
/** is this the root of a Bayes tree ? */ /// Return true if this clique is the root of a Bayes tree.
inline bool isRoot() const { return parent_.expired(); } inline bool isRoot() const { return parent_.expired(); }
/// Return the number of children.
size_t nrChildren() const { return children.size(); }
/// Return the child at index i.
const derived_ptr operator[](size_t i) const { return children.at(i); }
/** The size of subtree rooted at this clique, i.e., nr of Cliques */ /** The size of subtree rooted at this clique, i.e., nr of Cliques */
size_t treeSize() const; size_t treeSize() const;

View File

@ -49,7 +49,7 @@ class ClusterTree {
virtual ~Cluster() {} virtual ~Cluster() {}
const Cluster& operator[](size_t i) const { const Cluster& operator[](size_t i) const {
return *(children[i]); return *(children.at(i));
} }
/// Construct from factors associated with a single key /// Construct from factors associated with a single key
@ -161,7 +161,7 @@ class ClusterTree {
} }
const Cluster& operator[](size_t i) const { const Cluster& operator[](size_t i) const {
return *(roots_[i]); return *(roots_.at(i));
} }
/// @} /// @}

View File

@ -74,8 +74,9 @@ namespace gtsam {
EliminationTreeType etree(asDerived(), (*variableIndex).get(), ordering); EliminationTreeType etree(asDerived(), (*variableIndex).get(), ordering);
const auto [bayesNet, factorGraph] = etree.eliminate(function); const auto [bayesNet, factorGraph] = etree.eliminate(function);
// If any factors are remaining, the ordering was incomplete // If any factors are remaining, the ordering was incomplete
if(!factorGraph->empty()) if(!factorGraph->empty()) {
throw InconsistentEliminationRequested(); throw InconsistentEliminationRequested(factorGraph->keys());
}
// Return the Bayes net // Return the Bayes net
return bayesNet; return bayesNet;
} }
@ -136,8 +137,9 @@ namespace gtsam {
JunctionTreeType junctionTree(etree); JunctionTreeType junctionTree(etree);
const auto [bayesTree, factorGraph] = junctionTree.eliminate(function); const auto [bayesTree, factorGraph] = junctionTree.eliminate(function);
// If any factors are remaining, the ordering was incomplete // If any factors are remaining, the ordering was incomplete
if(!factorGraph->empty()) if(!factorGraph->empty()) {
throw InconsistentEliminationRequested(); throw InconsistentEliminationRequested(factorGraph->keys());
}
// Return the Bayes tree // Return the Bayes tree
return bayesTree; return bayesTree;
} }

View File

@ -51,12 +51,12 @@ namespace gtsam {
* algorithms. Any factor graph holding eliminateable factors can derive from this class to * algorithms. Any factor graph holding eliminateable factors can derive from this class to
* expose functions for computing marginals, conditional marginals, doing multifrontal and * expose functions for computing marginals, conditional marginals, doing multifrontal and
* sequential elimination, etc. */ * sequential elimination, etc. */
template<class FACTORGRAPH> template<class FACTOR_GRAPH>
class EliminateableFactorGraph class EliminateableFactorGraph
{ {
private: private:
typedef EliminateableFactorGraph<FACTORGRAPH> This; ///< Typedef to this class. typedef EliminateableFactorGraph<FACTOR_GRAPH> This; ///< Typedef to this class.
typedef FACTORGRAPH FactorGraphType; ///< Typedef to factor graph type typedef FACTOR_GRAPH FactorGraphType; ///< Typedef to factor graph type
// Base factor type stored in this graph (private because derived classes will get this from // Base factor type stored in this graph (private because derived classes will get this from
// their FactorGraph base class) // their FactorGraph base class)
typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType; typedef typename EliminationTraits<FactorGraphType>::FactorType _FactorType;

View File

@ -30,7 +30,7 @@
namespace gtsam { namespace gtsam {
class Ordering: public KeyVector { class GTSAM_EXPORT Ordering: public KeyVector {
protected: protected:
typedef KeyVector Base; typedef KeyVector Base;
@ -45,7 +45,6 @@ public:
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
/// Create an empty ordering /// Create an empty ordering
GTSAM_EXPORT
Ordering() { Ordering() {
} }
@ -101,7 +100,7 @@ public:
} }
/// Compute a fill-reducing ordering using COLAMD from a VariableIndex. /// Compute a fill-reducing ordering using COLAMD from a VariableIndex.
static GTSAM_EXPORT Ordering Colamd(const VariableIndex& variableIndex); static Ordering Colamd(const VariableIndex& variableIndex);
/// Compute a fill-reducing ordering using constrained COLAMD from a factor graph (see details /// Compute a fill-reducing ordering using constrained COLAMD from a factor graph (see details
/// for note on performance). This internally builds a VariableIndex so if you already have a /// for note on performance). This internally builds a VariableIndex so if you already have a
@ -126,7 +125,7 @@ public:
/// variables in \c constrainLast will be ordered in the same order specified in the KeyVector /// variables in \c constrainLast will be ordered in the same order specified in the KeyVector
/// \c constrainLast. If \c forceOrder is false, the variables in \c constrainLast will be /// \c constrainLast. If \c forceOrder is false, the variables in \c constrainLast will be
/// ordered after all the others, but will be rearranged by CCOLAMD to reduce fill-in as well. /// ordered after all the others, but will be rearranged by CCOLAMD to reduce fill-in as well.
static GTSAM_EXPORT Ordering ColamdConstrainedLast( static Ordering ColamdConstrainedLast(
const VariableIndex& variableIndex, const KeyVector& constrainLast, const VariableIndex& variableIndex, const KeyVector& constrainLast,
bool forceOrder = false); bool forceOrder = false);
@ -154,7 +153,7 @@ public:
/// KeyVector \c constrainFirst. If \c forceOrder is false, the variables in \c /// KeyVector \c constrainFirst. If \c forceOrder is false, the variables in \c
/// constrainFirst will be ordered before all the others, but will be rearranged by CCOLAMD to /// constrainFirst will be ordered before all the others, but will be rearranged by CCOLAMD to
/// reduce fill-in as well. /// reduce fill-in as well.
static GTSAM_EXPORT Ordering ColamdConstrainedFirst( static Ordering ColamdConstrainedFirst(
const VariableIndex& variableIndex, const VariableIndex& variableIndex,
const KeyVector& constrainFirst, bool forceOrder = false); const KeyVector& constrainFirst, bool forceOrder = false);
@ -183,7 +182,7 @@ public:
/// appear in \c groups in arbitrary order. Any variables not present in \c groups will be /// appear in \c groups in arbitrary order. Any variables not present in \c groups will be
/// assigned to group 0. This function simply fills the \c cmember argument to CCOLAMD with the /// assigned to group 0. This function simply fills the \c cmember argument to CCOLAMD with the
/// supplied indices, see the CCOLAMD documentation for more information. /// supplied indices, see the CCOLAMD documentation for more information.
static GTSAM_EXPORT Ordering ColamdConstrained( static Ordering ColamdConstrained(
const VariableIndex& variableIndex, const FastMap<Key, int>& groups); const VariableIndex& variableIndex, const FastMap<Key, int>& groups);
/// Return a natural Ordering. Typically used by iterative solvers /// Return a natural Ordering. Typically used by iterative solvers
@ -197,11 +196,11 @@ public:
/// METIS Formatting function /// METIS Formatting function
template<class FACTOR_GRAPH> template<class FACTOR_GRAPH>
static GTSAM_EXPORT void CSRFormat(std::vector<int>& xadj, static void CSRFormat(std::vector<int>& xadj,
std::vector<int>& adj, const FACTOR_GRAPH& graph); std::vector<int>& adj, const FACTOR_GRAPH& graph);
/// Compute an ordering determined by METIS from a VariableIndex /// Compute an ordering determined by METIS from a VariableIndex
static GTSAM_EXPORT Ordering Metis(const MetisIndex& met); static Ordering Metis(const MetisIndex& met);
template<class FACTOR_GRAPH> template<class FACTOR_GRAPH>
static Ordering Metis(const FACTOR_GRAPH& graph) { static Ordering Metis(const FACTOR_GRAPH& graph) {
@ -243,18 +242,16 @@ public:
/// @name Testable /// @name Testable
/// @{ /// @{
GTSAM_EXPORT
void print(const std::string& str = "", const KeyFormatter& keyFormatter = void print(const std::string& str = "", const KeyFormatter& keyFormatter =
DefaultKeyFormatter) const; DefaultKeyFormatter) const;
GTSAM_EXPORT
bool equals(const Ordering& other, double tol = 1e-9) const; bool equals(const Ordering& other, double tol = 1e-9) const;
/// @} /// @}
private: private:
/// Internal COLAMD function /// Internal COLAMD function
static GTSAM_EXPORT Ordering ColamdConstrained( static Ordering ColamdConstrained(
const VariableIndex& variableIndex, std::vector<int>& cmember); const VariableIndex& variableIndex, std::vector<int>& cmember);
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION

View File

@ -104,6 +104,7 @@ class Ordering {
// Standard Constructors and Named Constructors // Standard Constructors and Named Constructors
Ordering(); Ordering();
Ordering(const gtsam::Ordering& other); Ordering(const gtsam::Ordering& other);
Ordering(const std::vector<size_t>& keys);
template < template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
@ -148,7 +149,7 @@ class Ordering {
// Standard interface // Standard interface
size_t size() const; size_t size() const;
size_t at(size_t key) const; size_t at(size_t i) const;
void push_back(size_t key); void push_back(size_t key);
// enabling serialization functionality // enabling serialization functionality
@ -194,4 +195,15 @@ class VariableIndex {
size_t nEntries() const; size_t nEntries() const;
}; };
#include <gtsam/inference/Factor.h>
virtual class Factor {
void print(string s = "Factor\n", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void printKeys(string s = "") const;
bool equals(const gtsam::Factor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
gtsam::KeyVector keys() const;
};
} // namespace gtsam } // namespace gtsam

View File

@ -0,0 +1,60 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file inferenceExceptions.cpp
* @brief Exceptions that may be thrown by inference algorithms
* @author Richard Roberts, Varun Agrawal
* @date Apr 25, 2013
*/
#include <gtsam/inference/inferenceExceptions.h>
#include <sstream>
namespace gtsam {
InconsistentEliminationRequested::InconsistentEliminationRequested(
const KeySet& keys, const KeyFormatter& key_formatter)
: keys_(keys.begin(), keys.end()), keyFormatter(key_formatter) {}
const char* InconsistentEliminationRequested::what() const noexcept {
// Format keys for printing
std::stringstream sstr;
size_t nrKeysToDisplay = std::min(size_t(4), keys_.size());
for (size_t i = 0; i < nrKeysToDisplay; i++) {
sstr << keyFormatter(keys_.at(i));
if (i < nrKeysToDisplay - 1) {
sstr << ", ";
}
}
if (keys_.size() > nrKeysToDisplay) {
sstr << ", ... (total " << keys_.size() << " keys)";
}
sstr << ".";
std::string keys = sstr.str();
std::string msg =
"An inference algorithm was called with inconsistent "
"arguments. "
"The\n"
"factor graph, ordering, or variable index were "
"inconsistent with "
"each\n"
"other, or a full elimination routine was called with "
"an ordering "
"that\n"
"does not include all of the variables.\n";
msg += ("Leftover keys after elimination: " + keys);
// `new` to allocate memory on heap instead of stack
return (new std::string(msg))->c_str();
}
} // namespace gtsam

View File

@ -12,30 +12,35 @@
/** /**
* @file inferenceExceptions.h * @file inferenceExceptions.h
* @brief Exceptions that may be thrown by inference algorithms * @brief Exceptions that may be thrown by inference algorithms
* @author Richard Roberts * @author Richard Roberts, Varun Agrawal
* @date Apr 25, 2013 * @date Apr 25, 2013
*/ */
#pragma once #pragma once
#include <gtsam/global_includes.h> #include <gtsam/global_includes.h>
#include <gtsam/inference/Key.h>
#include <exception> #include <exception>
namespace gtsam { namespace gtsam {
/** An inference algorithm was called with inconsistent arguments. The factor graph, ordering, or /** An inference algorithm was called with inconsistent arguments. The factor
* variable index were inconsistent with each other, or a full elimination routine was called * graph, ordering, or variable index were inconsistent with each other, or a
* with an ordering that does not include all of the variables. */ * full elimination routine was called with an ordering that does not include
class InconsistentEliminationRequested : public std::exception { * all of the variables. */
public: class InconsistentEliminationRequested : public std::exception {
InconsistentEliminationRequested() noexcept {} KeyVector keys_;
~InconsistentEliminationRequested() noexcept override {} const KeyFormatter& keyFormatter = DefaultKeyFormatter;
const char* what() const noexcept override {
return
"An inference algorithm was called with inconsistent arguments. The\n"
"factor graph, ordering, or variable index were inconsistent with each\n"
"other, or a full elimination routine was called with an ordering that\n"
"does not include all of the variables.";
}
};
} public:
InconsistentEliminationRequested() noexcept {}
InconsistentEliminationRequested(
const KeySet& keys,
const KeyFormatter& key_formatter = DefaultKeyFormatter);
~InconsistentEliminationRequested() noexcept override {}
const char* what() const noexcept override;
};
} // namespace gtsam

View File

@ -99,7 +99,7 @@ namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const { void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const {
cout << s << " p("; cout << (s.empty() ? "" : s + " ") << "p(";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << (nrFrontals() > 1 ? " " : ""); cout << formatter(*it) << (nrFrontals() > 1 ? " " : "");
} }

View File

@ -261,8 +261,7 @@ class VectorValues {
}; };
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
virtual class GaussianFactor { virtual class GaussianFactor : gtsam::Factor {
gtsam::KeyVector keys() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter = void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::GaussianFactor& lf, double tol) const; bool equals(const gtsam::GaussianFactor& lf, double tol) const;
@ -273,8 +272,6 @@ virtual class GaussianFactor {
Matrix information() const; Matrix information() const;
Matrix augmentedJacobian() const; Matrix augmentedJacobian() const;
pair<Matrix, Vector> jacobian() const; pair<Matrix, Vector> jacobian() const;
size_t size() const;
bool empty() const;
}; };
#include <gtsam/linear/JacobianFactor.h> #include <gtsam/linear/JacobianFactor.h>
@ -301,10 +298,7 @@ virtual class JacobianFactor : gtsam::GaussianFactor {
//Testable //Testable
void print(string s = "", const gtsam::KeyFormatter& keyFormatter = void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
void printKeys(string s) const;
gtsam::KeyVector& keys() const;
bool equals(const gtsam::GaussianFactor& lf, double tol) const; bool equals(const gtsam::GaussianFactor& lf, double tol) const;
size_t size() const;
Vector unweighted_error(const gtsam::VectorValues& c) const; Vector unweighted_error(const gtsam::VectorValues& c) const;
Vector error_vector(const gtsam::VectorValues& c) const; Vector error_vector(const gtsam::VectorValues& c) const;
double error(const gtsam::VectorValues& c) const; double error(const gtsam::VectorValues& c) const;
@ -346,10 +340,8 @@ virtual class HessianFactor : gtsam::GaussianFactor {
HessianFactor(const gtsam::GaussianFactorGraph& factors); HessianFactor(const gtsam::GaussianFactorGraph& factors);
//Testable //Testable
size_t size() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter = void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
void printKeys(string s) const;
bool equals(const gtsam::GaussianFactor& lf, double tol) const; bool equals(const gtsam::GaussianFactor& lf, double tol) const;
double error(const gtsam::VectorValues& c) const; double error(const gtsam::VectorValues& c) const;

View File

@ -21,6 +21,7 @@
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/GaussianBayesNet.h> #include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/inference/VariableSlots.h> #include <gtsam/inference/VariableSlots.h>
#include <gtsam/inference/VariableIndex.h> #include <gtsam/inference/VariableIndex.h>
#include <gtsam/base/debug.h> #include <gtsam/base/debug.h>
@ -457,6 +458,64 @@ TEST(GaussianFactorGraph, ProbPrime) {
EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12); EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12);
} }
TEST(GaussianFactorGraph, InconsistentEliminationMessage) {
// Create empty graph
GaussianFactorGraph fg;
SharedDiagonal unit2 = noiseModel::Unit::Create(2);
using gtsam::symbol_shorthand::X;
fg.emplace_shared<JacobianFactor>(0, 10 * I_2x2, -1.0 * Vector::Ones(2),
unit2);
fg.emplace_shared<JacobianFactor>(0, -10 * I_2x2, 1, 10 * I_2x2,
Vector2(2.0, -1.0), unit2);
fg.emplace_shared<JacobianFactor>(1, -5 * I_2x2, 2, 5 * I_2x2,
Vector2(-1.0, 1.5), unit2);
fg.emplace_shared<JacobianFactor>(2, -5 * I_2x2, X(3), 5 * I_2x2,
Vector2(-1.0, 1.5), unit2);
Ordering ordering{0, 1};
try {
fg.eliminateSequential(ordering);
} catch (const exception& exc) {
std::string expected_exception_message = "An inference algorithm was called with inconsistent "
"arguments. "
"The\n"
"factor graph, ordering, or variable index were "
"inconsistent with "
"each\n"
"other, or a full elimination routine was called with "
"an ordering "
"that\n"
"does not include all of the variables.\n"
"Leftover keys after elimination: 2, x3.";
EXPECT(expected_exception_message == exc.what());
}
// Test large number of keys
fg = GaussianFactorGraph();
for (size_t i = 0; i < 1000; i++) {
fg.emplace_shared<JacobianFactor>(i, -I_2x2, i + 1, I_2x2,
Vector2(2.0, -1.0), unit2);
}
try {
fg.eliminateSequential(ordering);
} catch (const exception& exc) {
std::string expected_exception_message = "An inference algorithm was called with inconsistent "
"arguments. "
"The\n"
"factor graph, ordering, or variable index were "
"inconsistent with "
"each\n"
"other, or a full elimination routine was called with "
"an ordering "
"that\n"
"does not include all of the variables.\n"
"Leftover keys after elimination: 2, 3, 4, 5, ... (total 999 keys).";
EXPECT(expected_exception_message == exc.what());
}
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -109,13 +109,10 @@ class NonlinearFactorGraph {
}; };
#include <gtsam/nonlinear/NonlinearFactor.h> #include <gtsam/nonlinear/NonlinearFactor.h>
virtual class NonlinearFactor { virtual class NonlinearFactor : gtsam::Factor {
// Factor base class // Factor base class
size_t size() const;
gtsam::KeyVector keys() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter = void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
void printKeys(string s) const;
// NonlinearFactor // NonlinearFactor
bool equals(const gtsam::NonlinearFactor& other, double tol) const; bool equals(const gtsam::NonlinearFactor& other, double tol) const;
double error(const gtsam::Values& c) const; double error(const gtsam::Values& c) const;

View File

@ -894,6 +894,9 @@ template <size_t d>
std::pair<Values, double> ShonanAveraging<d>::run(const Values &initialEstimate, std::pair<Values, double> ShonanAveraging<d>::run(const Values &initialEstimate,
size_t pMin, size_t pMin,
size_t pMax) const { size_t pMax) const {
if (pMin < d) {
throw std::runtime_error("pMin is smaller than the base dimension d");
}
Values Qstar; Values Qstar;
Values initialSOp = LiftTo<Rot>(pMin, initialEstimate); // lift to pMin! Values initialSOp = LiftTo<Rot>(pMin, initialEstimate); // lift to pMin!
for (size_t p = pMin; p <= pMax; p++) { for (size_t p = pMin; p <= pMax; p++) {

View File

@ -415,6 +415,20 @@ TEST(ShonanAveraging3, PriorWeights) {
auto result = shonan.run(initial, 3, 3); auto result = shonan.run(initial, 3, 3);
EXPECT_DOUBLES_EQUAL(0.0015, shonan.cost(result.first), 1e-4); EXPECT_DOUBLES_EQUAL(0.0015, shonan.cost(result.first), 1e-4);
} }
/* ************************************************************************* */
// Check a small graph created using binary measurements
TEST(ShonanAveraging3, BinaryMeasurements) {
std::vector<BinaryMeasurement<Rot3>> measurements;
auto unit3 = noiseModel::Unit::Create(3);
measurements.emplace_back(0, 1, Rot3::Yaw(M_PI_2), unit3);
measurements.emplace_back(1, 2, Rot3::Yaw(M_PI_2), unit3);
ShonanAveraging3 shonan(measurements);
Values initial = shonan.initializeRandomly();
auto result = shonan.run(initial, 3, 5);
EXPECT_DOUBLES_EQUAL(0.0, shonan.cost(result.first), 1e-4);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -65,4 +65,6 @@ namespace gtsam {
SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree); SymbolicJunctionTree(const SymbolicEliminationTree& eliminationTree);
}; };
/// typedef for wrapper:
using SymbolicCluster = SymbolicJunctionTree::Cluster;
} }

View File

@ -4,7 +4,7 @@
namespace gtsam { namespace gtsam {
#include <gtsam/symbolic/SymbolicFactor.h> #include <gtsam/symbolic/SymbolicFactor.h>
virtual class SymbolicFactor { virtual class SymbolicFactor : gtsam::Factor {
// Standard Constructors and Named Constructors // Standard Constructors and Named Constructors
SymbolicFactor(const gtsam::SymbolicFactor& f); SymbolicFactor(const gtsam::SymbolicFactor& f);
SymbolicFactor(); SymbolicFactor();
@ -18,12 +18,10 @@ virtual class SymbolicFactor {
static gtsam::SymbolicFactor FromKeys(const gtsam::KeyVector& js); static gtsam::SymbolicFactor FromKeys(const gtsam::KeyVector& js);
// From Factor // From Factor
size_t size() const;
void print(string s = "SymbolicFactor", void print(string s = "SymbolicFactor",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::SymbolicFactor& other, double tol) const; bool equals(const gtsam::SymbolicFactor& other, double tol) const;
gtsam::KeyVector keys();
}; };
#include <gtsam/symbolic/SymbolicFactorGraph.h> #include <gtsam/symbolic/SymbolicFactorGraph.h>
@ -139,7 +137,60 @@ class SymbolicBayesNet {
const gtsam::DotWriter& writer = gtsam::DotWriter()) const; const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
}; };
#include <gtsam/symbolic/SymbolicEliminationTree.h>
class SymbolicEliminationTree {
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::VariableIndex& structure,
const gtsam::Ordering& order);
SymbolicEliminationTree(const gtsam::SymbolicFactorGraph& factorGraph,
const gtsam::Ordering& order);
void print(
string name = "EliminationTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::SymbolicEliminationTree& other,
double tol = 1e-9) const;
};
#include <gtsam/symbolic/SymbolicJunctionTree.h>
class SymbolicCluster {
gtsam::Ordering orderedFrontalKeys;
gtsam::SymbolicFactorGraph factors;
const gtsam::SymbolicCluster& operator[](size_t i) const;
size_t nrChildren() const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class SymbolicJunctionTree {
SymbolicJunctionTree(const gtsam::SymbolicEliminationTree& eliminationTree);
void print(
string name = "JunctionTree: ",
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
size_t nrRoots() const;
const gtsam::SymbolicCluster& operator[](size_t i) const;
};
#include <gtsam/symbolic/SymbolicBayesTree.h> #include <gtsam/symbolic/SymbolicBayesTree.h>
class SymbolicBayesTreeClique {
SymbolicBayesTreeClique();
SymbolicBayesTreeClique(const gtsam::SymbolicConditional* conditional);
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter);
const gtsam::SymbolicConditional* conditional() const;
bool isRoot() const;
gtsam::SymbolicBayesTreeClique* parent() const;
size_t treeSize() const;
size_t numCachedSeparatorMarginals() const;
void deleteCachedShortcuts();
};
class SymbolicBayesTree { class SymbolicBayesTree {
// Constructors // Constructors
SymbolicBayesTree(); SymbolicBayesTree();
@ -151,9 +202,14 @@ class SymbolicBayesTree {
bool equals(const gtsam::SymbolicBayesTree& other, double tol) const; bool equals(const gtsam::SymbolicBayesTree& other, double tol) const;
// Standard Interface // Standard Interface
// size_t findParentClique(const gtsam::IndexVector& parents) const; bool empty() const;
size_t size(); size_t size() const;
void saveGraph(string s) const;
const gtsam::SymbolicBayesTreeClique* operator[](size_t j) const;
void saveGraph(string s,
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void clear(); void clear();
void deleteCachedShortcuts(); void deleteCachedShortcuts();
size_t numCachedSeparatorMarginals() const; size_t numCachedSeparatorMarginals() const;
@ -161,28 +217,9 @@ class SymbolicBayesTree {
gtsam::SymbolicConditional* marginalFactor(size_t key) const; gtsam::SymbolicConditional* marginalFactor(size_t key) const;
gtsam::SymbolicFactorGraph* joint(size_t key1, size_t key2) const; gtsam::SymbolicFactorGraph* joint(size_t key1, size_t key2) const;
gtsam::SymbolicBayesNet* jointBayesNet(size_t key1, size_t key2) const; gtsam::SymbolicBayesNet* jointBayesNet(size_t key1, size_t key2) const;
};
class SymbolicBayesTreeClique { string dot(const gtsam::KeyFormatter& keyFormatter =
SymbolicBayesTreeClique(); gtsam::DefaultKeyFormatter) const;
// SymbolicBayesTreeClique(gtsam::sharedConditional* conditional);
bool equals(const gtsam::SymbolicBayesTreeClique& other, double tol) const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
size_t numCachedSeparatorMarginals() const;
// gtsam::sharedConditional* conditional() const;
bool isRoot() const;
size_t treeSize() const;
gtsam::SymbolicBayesTreeClique* parent() const;
// // TODO: need wrapped versions graphs, BayesNet
// BayesNet<ConditionalType> shortcut(derived_ptr root, Eliminate function)
// const; FactorGraph<FactorType> marginal(derived_ptr root, Eliminate
// function) const; FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr
// root, Eliminate function) const;
//
void deleteCachedShortcuts();
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -181,7 +181,7 @@ TEST(QPSolver, iterate) {
QPSolver::State state(currentSolution, VectorValues(), workingSet, false, QPSolver::State state(currentSolution, VectorValues(), workingSet, false,
100); 100);
int it = 0; // int it = 0;
while (!state.converged) { while (!state.converged) {
state = solver.iterate(state); state = solver.iterate(state);
// These checks will fail because the expected solutions obtained from // These checks will fail because the expected solutions obtained from
@ -190,7 +190,7 @@ TEST(QPSolver, iterate) {
// do not recompute dual variables after every step!!! // do not recompute dual variables after every step!!!
// CHECK(assert_equal(expected[it], state.values, 1e-10)); // CHECK(assert_equal(expected[it], state.values, 1e-10));
// CHECK(assert_equal(expectedDuals[it], state.duals, 1e-10)); // CHECK(assert_equal(expectedDuals[it], state.duals, 1e-10));
it++; // it++;
} }
CHECK(assert_equal(expected[3], state.values, 1e-10)); CHECK(assert_equal(expected[3], state.values, 1e-10));

View File

@ -26,7 +26,13 @@ class TestDecisionTreeFactor(GtsamTestCase):
self.B = (5, 2) self.B = (5, 2)
self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6") self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6")
def test_from_floats(self):
"""Test whether we can construct a factor from floats."""
actual = DecisionTreeFactor([self.A, self.B], [1., 2., 3., 4., 5., 6.])
self.gtsamAssertEquals(actual, self.factor)
def test_enumerate(self): def test_enumerate(self):
"""Test whether we can enumerate the factor."""
actual = self.factor.enumerate() actual = self.factor.enumerate()
_, values = zip(*actual) _, values = zip(*actual)
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

View File

@ -13,10 +13,15 @@ Author: Frank Dellaert
import unittest import unittest
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, import numpy as np
DiscreteConditional, DiscreteFactorGraph, Ordering) from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
import gtsam
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph,
DiscreteValues, Ordering)
class TestDiscreteBayesNet(GtsamTestCase): class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets.""" """Tests for Discrete Bayes Nets."""
@ -27,7 +32,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
# Define DiscreteKey pairs. # Define DiscreteKey pairs.
keys = [(j, 2) for j in range(15)] keys = [(j, 2) for j in range(15)]
# Create thin-tree Bayesnet. # Create thin-tree Bayes net.
bayesNet = DiscreteBayesNet() bayesNet = DiscreteBayesNet()
bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1") bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1")
@ -65,15 +70,91 @@ class TestDiscreteBayesNet(GtsamTestCase):
# bayesTree[key].printSignature() # bayesTree[key].printSignature()
# bayesTree.saveGraph("test_DiscreteBayesTree.dot") # bayesTree.saveGraph("test_DiscreteBayesTree.dot")
self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())
# The root is P( 8 12 14), we can retrieve it by key: # The root is P( 8 12 14), we can retrieve it by key:
root = bayesTree[8] root = bayesTree[8]
self.assertIsInstance(root, DiscreteBayesTreeClique) self.assertIsInstance(root, DiscreteBayesTreeClique)
self.assertTrue(root.isRoot()) self.assertTrue(root.isRoot())
self.assertIsInstance(root.conditional(), DiscreteConditional) self.assertIsInstance(root.conditional(), DiscreteConditional)
# Test all methods in DiscreteBayesTree
self.gtsamAssertEquals(bayesTree, bayesTree)
# Check value at 0
zero_values = DiscreteValues()
for j in range(15):
zero_values[j] = 0
value_at_zeros = bayesTree.evaluate(zero_values)
self.assertAlmostEqual(value_at_zeros, 0.0)
# Check value at max
values_star = factorGraph.optimize()
max_value = bayesTree.evaluate(values_star)
self.assertAlmostEqual(max_value, 0.002548)
# Check operator sugar
max_value = bayesTree(values_star)
self.assertAlmostEqual(max_value, 0.002548)
self.assertFalse(bayesTree.empty())
self.assertEqual(12, bayesTree.size())
def test_discrete_bayes_tree_lookup(self):
"""Check that we can have a multi-frontal lookup table."""
# Make a small planning-like graph: 3 states, 2 actions
graph = DiscreteFactorGraph()
x1, x2, x3 = (X(1), 3), (X(2), 3), (X(3), 3)
a1, a2 = (A(1), 2), (A(2), 2)
# Constraint on start and goal
graph.add([x1], np.array([1, 0, 0]))
graph.add([x3], np.array([0, 0, 1]))
# Should I stay or should I go?
# "Reward" (exp(-cost)) for an action is 10, and rewards multiply:
r = 10
table = np.array([
r, 0, 0, 0, r, 0, # x1 = 0
0, r, 0, 0, 0, r, # x1 = 1
0, 0, r, 0, 0, r # x1 = 2
])
graph.add([x1, a1, x2], table)
graph.add([x2, a2, x3], table)
# Eliminate for MPE (maximum probable explanation).
ordering = Ordering(keys=[A(2), X(3), X(1), A(1), X(2)])
lookup = graph.eliminateMultifrontal(ordering, gtsam.EliminateForMPE)
# Check that the lookup table is correct
assert lookup.size() == 2
lookup_x1_a1_x2 = lookup[X(1)].conditional()
assert lookup_x1_a1_x2.nrFrontals() == 3
# Check that sum is 1.0 (not 100, as we now normalize to prevent underflow)
empty = gtsam.DiscreteValues()
self.assertAlmostEqual(lookup_x1_a1_x2.sum(3)(empty), 1.0)
# And that only non-zero reward is for x1 a1 x2 == 0 1 1
values = DiscreteValues()
values[X(1)] = 0
values[A(1)] = 1
values[X(2)] = 1
self.assertAlmostEqual(lookup_x1_a1_x2(values), 1.0)
lookup_a2_x3 = lookup[X(3)].conditional()
# Check that the sum depends on x2 and is non-zero only for x2 in {1, 2}
sum_x2 = lookup_a2_x3.sum(2)
values = DiscreteValues()
values[X(2)] = 0
self.assertAlmostEqual(sum_x2(values), 0)
values[X(2)] = 1
self.assertAlmostEqual(sum_x2(values), 1.0) # not 10, as we normalize
values[X(2)] = 2
self.assertAlmostEqual(sum_x2(values), 2.0) # not 20, as we normalize
assert lookup_a2_x3.nrFrontals() == 2
# And that the non-zero rewards are for x2 a2 x3 == 1 1 2
values = DiscreteValues()
values[X(2)] = 1
values[A(2)] = 1
values[X(3)] = 2
self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10...
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -10,14 +10,16 @@ Author: Frank Dellaert
""" """
# pylint: disable=invalid-name, no-name-in-module, no-member # pylint: disable=invalid-name, no-name-in-module, no-member
import math
import unittest import unittest
import numpy as np import numpy as np
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import (BetweenFactorPose2, LevenbergMarquardtParams, Pose2, Rot2, from gtsam import (BetweenFactorPose2, BetweenFactorPose3,
ShonanAveraging2, ShonanAveraging3, BinaryMeasurementRot3, LevenbergMarquardtParams, Pose2,
Pose3, Rot2, Rot3, ShonanAveraging2, ShonanAveraging3,
ShonanAveragingParameters2, ShonanAveragingParameters3) ShonanAveragingParameters2, ShonanAveragingParameters3)
DEFAULT_PARAMS = ShonanAveragingParameters3( DEFAULT_PARAMS = ShonanAveragingParameters3(
@ -197,6 +199,19 @@ class TestShonanAveraging(GtsamTestCase):
expected_thetas_deg = np.array([0.0, 90.0, 0.0]) expected_thetas_deg = np.array([0.0, 90.0, 0.0])
np.testing.assert_allclose(thetas_deg, expected_thetas_deg, atol=0.1) np.testing.assert_allclose(thetas_deg, expected_thetas_deg, atol=0.1)
def test_measurements3(self):
"""Create from Measurements."""
measurements = []
unit3 = gtsam.noiseModel.Unit.Create(3)
m01 = BinaryMeasurementRot3(0, 1, Rot3.Yaw(math.radians(90)), unit3)
m12 = BinaryMeasurementRot3(1, 2, Rot3.Yaw(math.radians(90)), unit3)
measurements.append(m01)
measurements.append(m12)
obj = ShonanAveraging3(measurements)
self.assertIsInstance(obj, ShonanAveraging3)
initial = obj.initializeRandomly()
_, cost = obj.run(initial, min_p=3, max_p=5)
self.assertAlmostEqual(cost, 0)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -84,7 +84,7 @@ class TestVisualISAMExample(GtsamTestCase):
values.insert(key, v) values.insert(key, v)
self.assertAlmostEqual(isam.error(values), 34212421.14731998) self.assertAlmostEqual(isam.error(values), 34212421.14732)
def test_isam2_update(self): def test_isam2_update(self):
""" """