Merge pull request #1501 from borglab/fix-1496

release/4.3a0
Varun Agrawal 2023-11-11 10:29:39 -05:00 committed by GitHub
commit 1121ece0eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 289 additions and 178 deletions

View File

@ -31,6 +31,7 @@ option(GTSAM_FORCE_STATIC_LIB "Force gtsam to be a static library,
option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF) option(GTSAM_USE_QUATERNIONS "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF)
option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON) option(GTSAM_POSE3_EXPMAP "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON) option(GTSAM_ROT3_EXPMAP "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
option(GTSAM_DT_MERGING "Enable/Disable merging of equal leaf nodes in DecisionTrees. This leads to significant speed up and memory savings." ON)
option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF) option(GTSAM_ENABLE_CONSISTENCY_CHECKS "Enable/Disable expensive consistency checks" OFF)
option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF) option(GTSAM_ENABLE_MEMORY_SANITIZER "Enable/Disable memory sanitizer" OFF)
option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON) option(GTSAM_WITH_TBB "Use Intel Threaded Building Blocks (TBB) if available" ON)

View File

@ -90,6 +90,7 @@ print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS} "Runtime consistency c
print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ") print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER} "Build with Memory Sanitizer ")
print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ") print_enabled_config(${GTSAM_ROT3_EXPMAP} "Rot3 retract is full ExpMap ")
print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ") print_enabled_config(${GTSAM_POSE3_EXPMAP} "Pose3 retract is full ExpMap ")
print_enabled_config(${GTSAM_DT_MERGING} "Enable branch merging in DecisionTree")
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3") print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43} "Allow features deprecated in GTSAM 4.3")
print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ") print_enabled_config(${GTSAM_SUPPORT_NESTED_DISSECTION} "Metis-based Nested Dissection ")
print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration") print_enabled_config(${GTSAM_TANGENT_PREINTEGRATION} "Use tangent-space preintegration")

View File

@ -39,6 +39,9 @@
#cmakedefine GTSAM_ROT3_EXPMAP #cmakedefine GTSAM_ROT3_EXPMAP
#endif #endif
// Whether to enable merging of equal leaf nodes in the Discrete Decision Tree.
#cmakedefine GTSAM_DT_MERGING
// Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake) // Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake)
#cmakedefine GTSAM_USE_TBB #cmakedefine GTSAM_USE_TBB

View File

@ -53,26 +53,17 @@ namespace gtsam {
/** constant stored in this leaf */ /** constant stored in this leaf */
Y constant_; Y constant_;
/** The number of assignments contained within this leaf.
* Particularly useful when leaves have been pruned.
*/
size_t nrAssignments_;
/// Default constructor for serialization. /// Default constructor for serialization.
Leaf() {} Leaf() {}
/// Constructor from constant /// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1) Leaf(const Y& constant) : constant_(constant) {}
: constant_(constant), nrAssignments_(nrAssignments) {}
/// Return the constant /// Return the constant
const Y& constant() const { const Y& constant() const {
return constant_; return constant_;
} }
/// Return the number of assignments contained within this leaf.
size_t nrAssignments() const { return nrAssignments_; }
/// Leaf-Leaf equality /// Leaf-Leaf equality
bool sameLeaf(const Leaf& q) const override { bool sameLeaf(const Leaf& q) const override {
return constant_ == q.constant_; return constant_ == q.constant_;
@ -93,8 +84,7 @@ namespace gtsam {
/// print /// print
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf [" << nrAssignments() << "] " std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
<< valueFormatter(constant_) << std::endl;
} }
/** Write graphviz format to stream `os`. */ /** Write graphviz format to stream `os`. */
@ -114,14 +104,14 @@ namespace gtsam {
/** apply unary operator */ /** apply unary operator */
NodePtr apply(const Unary& op) const override { NodePtr apply(const Unary& op) const override {
NodePtr f(new Leaf(op(constant_), nrAssignments_)); NodePtr f(new Leaf(op(constant_)));
return f; return f;
} }
/// Apply unary operator with assignment /// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op, NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& assignment) const override { const Assignment<L>& assignment) const override {
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_)); NodePtr f(new Leaf(op(assignment, constant_)));
return f; return f;
} }
@ -137,7 +127,7 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf // Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
// fL op gL // fL op gL
NodePtr h(new Leaf(op(fL.constant_, constant_), nrAssignments_)); NodePtr h(new Leaf(op(fL.constant_, constant_)));
return h; return h;
} }
@ -148,7 +138,7 @@ namespace gtsam {
/** choose a branch, create new memory ! */ /** choose a branch, create new memory ! */
NodePtr choose(const L& label, size_t index) const override { NodePtr choose(const L& label, size_t index) const override {
return NodePtr(new Leaf(constant(), nrAssignments())); return NodePtr(new Leaf(constant()));
} }
bool isLeaf() const override { return true; } bool isLeaf() const override { return true; }
@ -163,7 +153,6 @@ namespace gtsam {
void serialize(ARCHIVE& ar, const unsigned int /*version*/) { void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_); ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
} }
#endif #endif
}; // Leaf }; // Leaf
@ -199,26 +188,50 @@ namespace gtsam {
#endif #endif
} }
/// If all branches of a choice node f are the same, just return a branch. /**
static NodePtr Unique(const ChoicePtr& f) { * @brief Merge branches with equal leaf values for every choice node in a
#ifndef GTSAM_DT_NO_PRUNING * decision tree. If all branches are the same (i.e. have the same leaf
* value), replace the choice node with the equivalent leaf node.
*
* This function applies the branch merging (if enabled) recursively on the
* decision tree represented by the root node passed in as the argument. It
* recurses to the leaf nodes and merges branches with equal leaf values in
* a bottom-up fashion.
*
* Thus, if all branches of a choice node `f` are the same,
* just return a single branch at each recursion step.
*
* @param node The root node of the decision tree.
* @return NodePtr
*/
static NodePtr Unique(const NodePtr& node) {
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
// Choice node, we recurse!
// Make non-const copy so we can update
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
// Iterate over all the branches
for (size_t i = 0; i < choice->nrChoices(); i++) {
auto branch = choice->branches_[i];
f->push_back(Unique(branch));
}
#ifdef GTSAM_DT_MERGING
// If all the branches are the same, we can merge them into one
if (f->allSame_) { if (f->allSame_) {
assert(f->branches().size() > 0); assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0]; NodePtr f0 = f->branches_[0];
size_t nrAssignments = 0;
for(auto branch: f->branches()) {
assert(branch->isLeaf());
nrAssignments +=
std::dynamic_pointer_cast<const Leaf>(branch)->nrAssignments();
}
NodePtr newLeaf( NodePtr newLeaf(
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant(), new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant()));
nrAssignments));
return newLeaf; return newLeaf;
} else }
#endif #endif
return f; return f;
} else {
// Leaf node, return as is
return node;
}
} }
bool isLeaf() const override { return false; } bool isLeaf() const override { return false; }
@ -439,8 +452,10 @@ namespace gtsam {
// second case, not label of interest, just recurse // second case, not label of interest, just recurse
auto r = std::make_shared<Choice>(label_, branches_.size()); auto r = std::make_shared<Choice>(label_, branches_.size());
for (auto&& branch : branches_) for (auto&& branch : branches_) {
r->push_back(branch->choose(label, index)); r->push_back(branch->choose(label, index));
}
return Unique(r); return Unique(r);
} }
@ -464,13 +479,11 @@ namespace gtsam {
// DecisionTree // DecisionTree
/****************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() { DecisionTree<L, Y>::DecisionTree() {}
}
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) : DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
root_(root) { root_(root) {}
}
/****************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
@ -586,7 +599,8 @@ namespace gtsam {
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin); auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
for (Iterator it = begin; it != end; it++) for (Iterator it = begin; it != end; it++)
choiceOnLabel->push_back(it->root_); choiceOnLabel->push_back(it->root_);
return Choice::Unique(choiceOnLabel); // If no reordering, no need to call Choice::Unique
return choiceOnLabel;
} else { } else {
// Set up a new choice on the highest label // Set up a new choice on the highest label
auto choiceOnHighestLabel = auto choiceOnHighestLabel =
@ -605,21 +619,21 @@ namespace gtsam {
NodePtr fi = compose(functions.begin(), functions.end(), label); NodePtr fi = compose(functions.begin(), functions.end(), label);
choiceOnHighestLabel->push_back(fi); choiceOnHighestLabel->push_back(fi);
} }
return Choice::Unique(choiceOnHighestLabel); return choiceOnHighestLabel;
} }
} }
/****************************************************************************/ /****************************************************************************/
// "create" is a bit of a complicated thing, but very useful. // "build" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values, // It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows: // and builds a decision tree, as follows:
// - if there is only one label, creates a choice node with values in leaves // - if there is only one label, creates a choice node with values in leaves
// - otherwise, it evenly splits up the range of values and creates a tree for // - otherwise, it evenly splits up the range of values and creates a tree for
// each sub-range, and assigns that tree to first label's choices // each sub-range, and assigns that tree to first label's choices
// Example: // Example:
// create([B A],[1 2 3 4]) would call // build([B A],[1 2 3 4]) would call
// create([A],[1 2]) // build([A],[1 2])
// create([A],[3 4]) // build([A],[3 4])
// and produce // and produce
// B=0 // B=0
// A=0: 1 // A=0: 1
@ -632,7 +646,7 @@ namespace gtsam {
// However, it will be *way* faster if labels are given highest to lowest. // However, it will be *way* faster if labels are given highest to lowest.
template<typename L, typename Y> template<typename L, typename Y>
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::build(
It begin, It end, ValueIt beginY, ValueIt endY) const { It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts // get crucial counts
size_t nrChoices = begin->second; size_t nrChoices = begin->second;
@ -650,9 +664,10 @@ namespace gtsam {
throw std::invalid_argument("DecisionTree::create invalid argument"); throw std::invalid_argument("DecisionTree::create invalid argument");
} }
auto choice = std::make_shared<Choice>(begin->first, endY - beginY); auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
for (ValueIt y = beginY; y != endY; y++) for (ValueIt y = beginY; y != endY; y++) {
choice->push_back(NodePtr(new Leaf(*y))); choice->push_back(NodePtr(new Leaf(*y)));
return Choice::Unique(choice); }
return choice;
} }
// Recursive case: perform "Shannon expansion" // Recursive case: perform "Shannon expansion"
@ -661,12 +676,27 @@ namespace gtsam {
std::vector<DecisionTree> functions; std::vector<DecisionTree> functions;
size_t split = size / nrChoices; size_t split = size / nrChoices;
for (size_t i = 0; i < nrChoices; i++, beginY += split) { for (size_t i = 0; i < nrChoices; i++, beginY += split) {
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split); NodePtr f = build<It, ValueIt>(labelC, end, beginY, beginY + split);
functions.emplace_back(f); functions.emplace_back(f);
} }
return compose(functions.begin(), functions.end(), begin->first); return compose(functions.begin(), functions.end(), begin->first);
} }
/****************************************************************************/
// Top-level factory method, which takes a range of labels and a corresponding
// range of values, and creates a decision tree.
template<typename L, typename Y>
template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const {
auto node = build(begin, end, beginY, endY);
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
return Choice::Unique(choice);
} else {
return node;
}
}
/****************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X> template <typename M, typename X>
@ -681,7 +711,7 @@ namespace gtsam {
// If leaf, apply unary conversion "op" and create a unique leaf. // If leaf, apply unary conversion "op" and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf; using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) { if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant()), leaf->nrAssignments())); return NodePtr(new Leaf(Y_of_X(leaf->constant())));
} }
// Check if Choice // Check if Choice
@ -699,7 +729,7 @@ namespace gtsam {
for (auto&& branch : choice->branches()) { for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X)); functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
} }
return LY::compose(functions.begin(), functions.end(), newLabel); return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
} }
/****************************************************************************/ /****************************************************************************/
@ -828,16 +858,6 @@ namespace gtsam {
return total; return total;
} }
/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrAssignments() const {
size_t n = 0;
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
n += leaf.nrAssignments();
});
return n;
}
/****************************************************************************/ /****************************************************************************/
// fold is just done with a visit // fold is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>

View File

@ -154,7 +154,15 @@ namespace gtsam {
* Internal recursive function to create from keys, cardinalities, * Internal recursive function to create from keys, cardinalities,
* and Y values * and Y values
*/ */
template<typename It, typename ValueIt> template <typename It, typename ValueIt>
NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const;
/** Internal helper function to create from
* keys, cardinalities, and Y values.
* Calls `build` which builds thetree bottom-up,
* before we prune in a top-down fashion.
*/
template <typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
/** /**
@ -320,42 +328,6 @@ namespace gtsam {
/// Return the number of leaves in the tree. /// Return the number of leaves in the tree.
size_t nrLeaves() const; size_t nrLeaves() const;
/**
* @brief This is a convenience function which returns the total number of
* leaf assignments in the decision tree.
* This function is not used for anymajor operations within the discrete
* factor graph framework.
*
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
* binary tree each leaf has 2 assignments. This includes counts removed
* from implicit pruning hence, it will always be >= nrLeaves().
*
* E.g. we have a decision tree as below, where each node has 2 branches:
*
* Choice(m1)
* 0 Choice(m0)
* 0 0 Leaf 0.0
* 0 1 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
* and 4 leaves.
*
* In the pruned form, the number of assignments is still 4 but the number
* of leaves is now 3, as below:
*
* Choice(m1)
* 0 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* @return size_t
*/
size_t nrAssignments() const;
/** /**
* @brief Fold a binary function over the tree, returning accumulator. * @brief Fold a binary function over the tree, returning accumulator.
* *

View File

@ -20,7 +20,6 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
//#define GTSAM_DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING #define DISABLE_TIMING
@ -179,7 +178,11 @@ TEST(ADT, joint) {
dot(joint, "Asia-ASTLBEX"); dot(joint, "Asia-ASTLBEX");
joint = apply(joint, pD, &mul); joint = apply(joint, pD, &mul);
dot(joint, "Asia-ASTLBEXD"); dot(joint, "Asia-ASTLBEXD");
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(346, muls); EXPECT_LONGS_EQUAL(346, muls);
#else
EXPECT_LONGS_EQUAL(508, muls);
#endif
gttoc_(asiaJoint); gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint); tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall(); elapsed = asiaJointNode->secs() + asiaJointNode->wall();
@ -240,7 +243,11 @@ TEST(ADT, inference) {
dot(joint, "Joint-Product-ASTLBEX"); dot(joint, "Joint-Product-ASTLBEX");
joint = apply(joint, pD, &mul); joint = apply(joint, pD, &mul);
dot(joint, "Joint-Product-ASTLBEXD"); dot(joint, "Joint-Product-ASTLBEXD");
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
#else
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
#endif
gttoc_(asiaProd); gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd); tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall(); elapsed = asiaProdNode->secs() + asiaProdNode->wall();
@ -258,7 +265,11 @@ TEST(ADT, inference) {
dot(marginal, "Joint-Sum-ADBLE"); dot(marginal, "Joint-Sum-ADBLE");
marginal = marginal.combine(E, &add_); marginal = marginal.combine(E, &add_);
dot(marginal, "Joint-Sum-ADBL"); dot(marginal, "Joint-Sum-ADBL");
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(161, (long)adds); EXPECT_LONGS_EQUAL(161, (long)adds);
#else
EXPECT_LONGS_EQUAL(240, (long)adds);
#endif
gttoc_(asiaSum); gttoc_(asiaSum);
tictoc_getNode(asiaSumNode, asiaSum); tictoc_getNode(asiaSumNode, asiaSum);
elapsed = asiaSumNode->secs() + asiaSumNode->wall(); elapsed = asiaSumNode->secs() + asiaSumNode->wall();
@ -296,7 +307,11 @@ TEST(ADT, factor_graph) {
fg = apply(fg, pX, &mul); fg = apply(fg, pX, &mul);
fg = apply(fg, pD, &mul); fg = apply(fg, pD, &mul);
dot(fg, "FactorGraph"); dot(fg, "FactorGraph");
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(158, (long)muls); EXPECT_LONGS_EQUAL(158, (long)muls);
#else
EXPECT_LONGS_EQUAL(188, (long)muls);
#endif
gttoc_(asiaFG); gttoc_(asiaFG);
tictoc_getNode(asiaFGNode, asiaFG); tictoc_getNode(asiaFGNode, asiaFG);
elapsed = asiaFGNode->secs() + asiaFGNode->wall(); elapsed = asiaFGNode->secs() + asiaFGNode->wall();
@ -315,7 +330,11 @@ TEST(ADT, factor_graph) {
dot(fg, "Marginalized-3E"); dot(fg, "Marginalized-3E");
fg = fg.combine(L, &add_); fg = fg.combine(L, &add_);
dot(fg, "Marginalized-2L"); dot(fg, "Marginalized-2L");
#ifdef GTSAM_DT_MERGING
LONGS_EQUAL(49, adds); LONGS_EQUAL(49, adds);
#else
LONGS_EQUAL(62, adds);
#endif
gttoc_(marg); gttoc_(marg);
tictoc_getNode(margNode, marg); tictoc_getNode(margNode, marg);
elapsed = margNode->secs() + margNode->wall(); elapsed = margNode->secs() + margNode->wall();

View File

@ -18,7 +18,6 @@
*/ */
// #define DT_DEBUG_MEMORY // #define DT_DEBUG_MEMORY
// #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT #define DISABLE_DOT
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
@ -234,7 +233,11 @@ TEST(DecisionTree, Example) {
// Test choose 0 // Test choose 0
DT actual0 = notba.choose(A, 0); DT actual0 = notba.choose(A, 0);
#ifdef GTSAM_DT_MERGING
EXPECT(assert_equal(DT(0.0), actual0)); EXPECT(assert_equal(DT(0.0), actual0));
#else
EXPECT(assert_equal(DT({0.0, 0.0}), actual0));
#endif
DOT(actual0); DOT(actual0);
// Test choose 1 // Test choose 1
@ -367,49 +370,6 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int); StringContainerTree converted(stringIntTree, container_of_int);
} }
/* ************************************************************************** */
// Test nrAssignments.
TEST(DecisionTree, NrAssignments) {
const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");
EXPECT(tree.root_->isLeaf());
auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
DT tree2({C, B, A}, "1 1 1 2 3 4 5 5");
/* The tree is
Choice(C)
0 Choice(B)
0 0 Leaf 1
0 1 Choice(A)
0 1 0 Leaf 1
0 1 1 Leaf 2
1 Choice(B)
1 0 Choice(A)
1 0 0 Leaf 3
1 0 1 Leaf 4
1 1 Leaf 5
*/
auto root = std::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
CHECK(root);
auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
CHECK(choice0);
EXPECT(choice0->branches()[0]->isLeaf());
auto choice00 = std::dynamic_pointer_cast<const DT::Leaf>(choice0->branches()[0]);
CHECK(choice00);
EXPECT_LONGS_EQUAL(2, choice00->nrAssignments());
auto choice1 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[1]);
CHECK(choice1);
auto choice10 = std::dynamic_pointer_cast<const DT::Choice>(choice1->branches()[0]);
CHECK(choice10);
auto choice11 = std::dynamic_pointer_cast<const DT::Leaf>(choice1->branches()[1]);
CHECK(choice11);
EXPECT(choice11->isLeaf());
EXPECT_LONGS_EQUAL(2, choice11->nrAssignments());
}
/* ************************************************************************** */ /* ************************************************************************** */
// Test visit. // Test visit.
TEST(DecisionTree, visit) { TEST(DecisionTree, visit) {
@ -449,10 +409,15 @@ TEST(DecisionTree, VisitWithPruned) {
}; };
tree.visitWith(func); tree.visitWith(func);
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(6, choices.size()); EXPECT_LONGS_EQUAL(6, choices.size());
#else
EXPECT_LONGS_EQUAL(8, choices.size());
#endif
Assignment<string> expectedAssignment; Assignment<string> expectedAssignment;
#ifdef GTSAM_DT_MERGING
expectedAssignment = {{"B", 0}, {"C", 0}}; expectedAssignment = {{"B", 0}, {"C", 0}};
EXPECT(expectedAssignment == choices.at(0)); EXPECT(expectedAssignment == choices.at(0));
@ -470,6 +435,25 @@ TEST(DecisionTree, VisitWithPruned) {
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}}; expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}};
EXPECT(expectedAssignment == choices.at(5)); EXPECT(expectedAssignment == choices.at(5));
#else
expectedAssignment = {{"A", 0}, {"B", 0}, {"C", 0}};
EXPECT(expectedAssignment == choices.at(0));
expectedAssignment = {{"A", 1}, {"B", 0}, {"C", 0}};
EXPECT(expectedAssignment == choices.at(1));
expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}};
EXPECT(expectedAssignment == choices.at(2));
expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}};
EXPECT(expectedAssignment == choices.at(3));
expectedAssignment = {{"A", 0}, {"B", 0}, {"C", 1}};
EXPECT(expectedAssignment == choices.at(4));
expectedAssignment = {{"A", 1}, {"B", 0}, {"C", 1}};
EXPECT(expectedAssignment == choices.at(5));
#endif
} }
/* ************************************************************************** */ /* ************************************************************************** */
@ -480,7 +464,11 @@ TEST(DecisionTree, fold) {
DT tree(B, DT(A, 1, 1), DT(A, 2, 3)); DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
auto add = [](const int& y, double x) { return y + x; }; auto add = [](const int& y, double x) { return y + x; };
double sum = tree.fold(add, 0.0); double sum = tree.fold(add, 0.0);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning! #ifdef GTSAM_DT_MERGING
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to merging!
#else
EXPECT_DOUBLES_EQUAL(7.0, sum, 1e-9);
#endif
} }
/* ************************************************************************** */ /* ************************************************************************** */
@ -532,9 +520,14 @@ TEST(DecisionTree, threshold) {
auto threshold = [](int value) { return value < 5 ? 0 : value; }; auto threshold = [](int value) { return value < 5 ? 0 : value; };
DT thresholded(tree, threshold); DT thresholded(tree, threshold);
#ifdef GTSAM_DT_MERGING
// Check number of leaves equal to zero now = 2 // Check number of leaves equal to zero now = 2
// Note: it is 2, because the pruned branches are counted as 1! // Note: it is 2, because the pruned branches are counted as 1!
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0)); EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
#else
// if GTSAM_DT_MERGING is disabled, the count will be larger
EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0));
#endif
} }
/* ************************************************************************** */ /* ************************************************************************** */
@ -570,8 +563,13 @@ TEST(DecisionTree, ApplyWithAssignment) {
}; };
DT prunedTree2 = prunedTree.apply(counter); DT prunedTree2 = prunedTree.apply(counter);
#ifdef GTSAM_DT_MERGING
// Check if apply doesn't enumerate all leaves. // Check if apply doesn't enumerate all leaves.
EXPECT_LONGS_EQUAL(5, count); EXPECT_LONGS_EQUAL(5, count);
#else
// if GTSAM_DT_MERGING is disabled, the count will be full
EXPECT_LONGS_EQUAL(8, count);
#endif
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -22,6 +22,8 @@
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h> #include <gtsam/inference/Symbol.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
using namespace std; using namespace std;

View File

@ -15,17 +15,20 @@
* @author Duy-Nguyen Ta * @author Duy-Nguyen Ta
*/ */
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/Symbol.h>
#include <CppUnitLite/TestHarness.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using symbol_shorthand::M;
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
@ -345,6 +348,7 @@ TEST(DiscreteFactorGraph, markdown) {
values[1] = 0; values[1] = 0;
EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9); EXPECT_DOUBLES_EQUAL(0.3, graph[0]->operator()(values), 1e-9);
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -70,8 +70,7 @@ Ordering HybridGaussianISAM::GetOrdering(
/* ************************************************************************* */ /* ************************************************************************* */
void HybridGaussianISAM::updateInternal( void HybridGaussianISAM::updateInternal(
const HybridGaussianFactorGraph& newFactors, const HybridGaussianFactorGraph& newFactors,
HybridBayesTree::Cliques* orphans, HybridBayesTree::Cliques* orphans, const std::optional<size_t>& maxNrLeaves,
const std::optional<size_t>& maxNrLeaves,
const std::optional<Ordering>& ordering, const std::optional<Ordering>& ordering,
const HybridBayesTree::Eliminate& function) { const HybridBayesTree::Eliminate& function) {
// Remove the contaminated part of the Bayes tree // Remove the contaminated part of the Bayes tree
@ -101,8 +100,8 @@ void HybridGaussianISAM::updateInternal(
} }
// eliminate all factors (top, added, orphans) into a new Bayes tree // eliminate all factors (top, added, orphans) into a new Bayes tree
HybridBayesTree::shared_ptr bayesTree = HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal(
factors.eliminateMultifrontal(elimination_ordering, function, std::cref(index)); elimination_ordering, function, std::cref(index));
if (maxNrLeaves) { if (maxNrLeaves) {
bayesTree->prune(*maxNrLeaves); bayesTree->prune(*maxNrLeaves);

View File

@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) {
std::string expected = std::string expected =
R"(Hybrid [x1 x2; 1]{ R"(Hybrid [x1 x2; 1]{
Choice(1) Choice(1)
0 Leaf [1] : 0 Leaf :
A[x1] = [ A[x1] = [
0; 0;
0 0
@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) {
b = [ 0 0 ] b = [ 0 0 ]
No noise model No noise model
1 Leaf [1] : 1 Leaf :
A[x1] = [ A[x1] = [
0; 0;
0 0

View File

@ -296,8 +296,12 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
std::make_shared<DecisionTreeFactor>( std::make_shared<DecisionTreeFactor>(
discreteConditionals.prune(maxNrLeaves)); discreteConditionals.prune(maxNrLeaves));
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
prunedDecisionTree->nrLeaves()); prunedDecisionTree->nrLeaves());
#else
EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree->nrLeaves());
#endif
// regression // regression
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}}; DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};

View File

@ -481,6 +481,7 @@ TEST(HybridFactorGraph, Printing) {
const auto [hybridBayesNet, remainingFactorGraph] = const auto [hybridBayesNet, remainingFactorGraph] =
linearizedFactorGraph.eliminatePartialSequential(ordering); linearizedFactorGraph.eliminatePartialSequential(ordering);
#ifdef GTSAM_DT_MERGING
string expected_hybridFactorGraph = R"( string expected_hybridFactorGraph = R"(
size: 7 size: 7
factor 0: factor 0:
@ -492,7 +493,7 @@ factor 0:
factor 1: factor 1:
Hybrid [x0 x1; m0]{ Hybrid [x0 x1; m0]{
Choice(m0) Choice(m0)
0 Leaf [1] : 0 Leaf :
A[x0] = [ A[x0] = [
-1 -1
] ]
@ -502,7 +503,7 @@ Hybrid [x0 x1; m0]{
b = [ -1 ] b = [ -1 ]
No noise model No noise model
1 Leaf [1] : 1 Leaf :
A[x0] = [ A[x0] = [
-1 -1
] ]
@ -516,7 +517,7 @@ Hybrid [x0 x1; m0]{
factor 2: factor 2:
Hybrid [x1 x2; m1]{ Hybrid [x1 x2; m1]{
Choice(m1) Choice(m1)
0 Leaf [1] : 0 Leaf :
A[x1] = [ A[x1] = [
-1 -1
] ]
@ -526,7 +527,7 @@ Hybrid [x1 x2; m1]{
b = [ -1 ] b = [ -1 ]
No noise model No noise model
1 Leaf [1] : 1 Leaf :
A[x1] = [ A[x1] = [
-1 -1
] ]
@ -550,18 +551,104 @@ factor 4:
b = [ -10 ] b = [ -10 ]
No noise model No noise model
factor 5: P( m0 ): factor 5: P( m0 ):
Leaf [2] 0.5 Leaf 0.5
factor 6: P( m1 | m0 ): factor 6: P( m1 | m0 ):
Choice(m1) Choice(m1)
0 Choice(m0) 0 Choice(m0)
0 0 Leaf [1] 0.33333333 0 0 Leaf 0.33333333
0 1 Leaf [1] 0.6 0 1 Leaf 0.6
1 Choice(m0) 1 Choice(m0)
1 0 Leaf [1] 0.66666667 1 0 Leaf 0.66666667
1 1 Leaf [1] 0.4 1 1 Leaf 0.4
)"; )";
#else
string expected_hybridFactorGraph = R"(
size: 7
factor 0:
A[x0] = [
10
]
b = [ -10 ]
No noise model
factor 1:
Hybrid [x0 x1; m0]{
Choice(m0)
0 Leaf:
A[x0] = [
-1
]
A[x1] = [
1
]
b = [ -1 ]
No noise model
1 Leaf:
A[x0] = [
-1
]
A[x1] = [
1
]
b = [ -0 ]
No noise model
}
factor 2:
Hybrid [x1 x2; m1]{
Choice(m1)
0 Leaf:
A[x1] = [
-1
]
A[x2] = [
1
]
b = [ -1 ]
No noise model
1 Leaf:
A[x1] = [
-1
]
A[x2] = [
1
]
b = [ -0 ]
No noise model
}
factor 3:
A[x1] = [
10
]
b = [ -10 ]
No noise model
factor 4:
A[x2] = [
10
]
b = [ -10 ]
No noise model
factor 5: P( m0 ):
Choice(m0)
0 Leaf 0.5
1 Leaf 0.5
factor 6: P( m1 | m0 ):
Choice(m1)
0 Choice(m0)
0 0 Leaf 0.33333333
0 1 Leaf 0.6
1 Choice(m0)
1 0 Leaf 0.66666667
1 1 Leaf 0.4
)";
#endif
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph)); EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
// Expected output for hybridBayesNet. // Expected output for hybridBayesNet.
@ -570,13 +657,13 @@ size: 3
conditional 0: Hybrid P( x0 | x1 m0) conditional 0: Hybrid P( x0 | x1 m0)
Discrete Keys = (m0, 2), Discrete Keys = (m0, 2),
Choice(m0) Choice(m0)
0 Leaf [1] p(x0 | x1) 0 Leaf p(x0 | x1)
R = [ 10.0499 ] R = [ 10.0499 ]
S[x1] = [ -0.0995037 ] S[x1] = [ -0.0995037 ]
d = [ -9.85087 ] d = [ -9.85087 ]
No noise model No noise model
1 Leaf [1] p(x0 | x1) 1 Leaf p(x0 | x1)
R = [ 10.0499 ] R = [ 10.0499 ]
S[x1] = [ -0.0995037 ] S[x1] = [ -0.0995037 ]
d = [ -9.95037 ] d = [ -9.95037 ]
@ -586,26 +673,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1)
Discrete Keys = (m0, 2), (m1, 2), Discrete Keys = (m0, 2), (m1, 2),
Choice(m1) Choice(m1)
0 Choice(m0) 0 Choice(m0)
0 0 Leaf [1] p(x1 | x2) 0 0 Leaf p(x1 | x2)
R = [ 10.099 ] R = [ 10.099 ]
S[x2] = [ -0.0990196 ] S[x2] = [ -0.0990196 ]
d = [ -9.99901 ] d = [ -9.99901 ]
No noise model No noise model
0 1 Leaf [1] p(x1 | x2) 0 1 Leaf p(x1 | x2)
R = [ 10.099 ] R = [ 10.099 ]
S[x2] = [ -0.0990196 ] S[x2] = [ -0.0990196 ]
d = [ -9.90098 ] d = [ -9.90098 ]
No noise model No noise model
1 Choice(m0) 1 Choice(m0)
1 0 Leaf [1] p(x1 | x2) 1 0 Leaf p(x1 | x2)
R = [ 10.099 ] R = [ 10.099 ]
S[x2] = [ -0.0990196 ] S[x2] = [ -0.0990196 ]
d = [ -10.098 ] d = [ -10.098 ]
No noise model No noise model
1 1 Leaf [1] p(x1 | x2) 1 1 Leaf p(x1 | x2)
R = [ 10.099 ] R = [ 10.099 ]
S[x2] = [ -0.0990196 ] S[x2] = [ -0.0990196 ]
d = [ -10 ] d = [ -10 ]
@ -615,14 +702,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
Discrete Keys = (m0, 2), (m1, 2), Discrete Keys = (m0, 2), (m1, 2),
Choice(m1) Choice(m1)
0 Choice(m0) 0 Choice(m0)
0 0 Leaf [1] p(x2) 0 0 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.1489 ] d = [ -10.1489 ]
mean: 1 elements mean: 1 elements
x2: -1.0099 x2: -1.0099
No noise model No noise model
0 1 Leaf [1] p(x2) 0 1 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.1479 ] d = [ -10.1479 ]
mean: 1 elements mean: 1 elements
@ -630,14 +717,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
No noise model No noise model
1 Choice(m0) 1 Choice(m0)
1 0 Leaf [1] p(x2) 1 0 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.0504 ] d = [ -10.0504 ]
mean: 1 elements mean: 1 elements
x2: -1.0001 x2: -1.0001
No noise model No noise model
1 1 Leaf [1] p(x2) 1 1 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.0494 ] d = [ -10.0494 ]
mean: 1 elements mean: 1 elements

View File

@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) {
R"(Hybrid [x1 x2; 1] R"(Hybrid [x1 x2; 1]
MixtureFactor MixtureFactor
Choice(1) Choice(1)
0 Leaf [1] Nonlinear factor on 2 keys 0 Leaf Nonlinear factor on 2 keys
1 Leaf [1] Nonlinear factor on 2 keys 1 Leaf Nonlinear factor on 2 keys
)"; )";
EXPECT(assert_print_equal(expected, mixtureFactor)); EXPECT(assert_print_equal(expected, mixtureFactor));
} }

View File

@ -10,10 +10,11 @@ Author: Fan Jiang
""" """
import unittest import unittest
import gtsam
import numpy as np import numpy as np
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
import gtsam
class TestRobust(GtsamTestCase): class TestRobust(GtsamTestCase):
@ -37,7 +38,7 @@ class TestRobust(GtsamTestCase):
v = gtsam.Values() v = gtsam.Values()
v.insert(0, 0.0) v.insert(0, 0.0)
self.assertAlmostEquals(f.error(v), 0.125) self.assertAlmostEqual(f.error(v), 0.125)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()