Merge pull request #1881 from borglab/feature/no_conditionals

Significant speedup
release/4.3a0
Frank Dellaert 2024-10-23 10:36:08 -07:00 committed by GitHub
commit 366b51432a
6 changed files with 292 additions and 116 deletions

View File

@ -29,6 +29,7 @@
 #include <optional>
 #include <set>
 #include <sstream>
+#include <stdexcept>
 #include <string>
 #include <vector>
@ -286,6 +287,10 @@ namespace gtsam {
     return branches_;
   }

+  std::vector<NodePtr>& branches() {
+    return branches_;
+  }
+
   /** add a branch: TODO merge into constructor */
   void push_back(NodePtr&& node) {
     // allSame_ is restricted to leaf nodes in a decision tree
@ -482,8 +487,8 @@ namespace gtsam {
 /****************************************************************************/
 // DecisionTree
 /****************************************************************************/
-template<typename L, typename Y>
-DecisionTree<L, Y>::DecisionTree() {}
+template <typename L, typename Y>
+DecisionTree<L, Y>::DecisionTree() : root_(nullptr) {}

 template<typename L, typename Y>
 DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
@ -554,6 +559,36 @@ namespace gtsam {
   root_ = compose(functions.begin(), functions.end(), label);
 }

+/****************************************************************************/
+template <typename L, typename Y>
+DecisionTree<L, Y>::DecisionTree(const Unary& op,
+                                 DecisionTree&& other) noexcept
+    : root_(std::move(other.root_)) {
+  // Apply the unary operation directly to each leaf in the tree
+  if (root_) {
+    // Define a helper function to traverse and apply the operation
+    struct ApplyUnary {
+      const Unary& op;
+      void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
+        if (auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
+          // Apply the unary operation to the leaf's constant value
+          leaf->constant_ = op(leaf->constant_);
+        } else if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
+          // Recurse into the choice branches
+          for (NodePtr& branch : choice->branches()) {
+            (*this)(branch);
+          }
+        }
+      }
+    };
+
+    ApplyUnary applyUnary{op};
+    applyUnary(root_);
+  }
+  // Reset the other tree's root to nullptr to avoid dangling references
+  other.root_ = nullptr;
+}
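Usage sketch (illustrative only, not part of the commit; it mirrors the Consume unit test added in the test file below and assumes the DecisionTree and Assignment headers plus <cassert> are included):

DecisionTree<std::string, int> tree("B", 1, 2);  // two leaves under label "B"
// Apply a unary op while stealing the nodes of `tree`: no allocation, leaves
// are rewritten in place, and `tree` is left empty afterwards.
DecisionTree<std::string, int> doubled([](int v) { return 2 * v; },
                                       std::move(tree));
assert(doubled(Assignment<std::string>{{"B", 0}}) == 2);
assert(doubled(Assignment<std::string>{{"B", 1}}) == 4);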
 /****************************************************************************/
 template <typename L, typename Y>
 template <typename X, typename Func>
@ -694,7 +729,7 @@ namespace gtsam {
 typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
     It begin, It end, ValueIt beginY, ValueIt endY) {
   auto node = build(begin, end, beginY, endY);
-  if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
+  if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
     return Choice::Unique(choice);
   } else {
     return node;
@ -710,7 +745,7 @@ namespace gtsam {
   // If leaf, apply unary conversion "op" and create a unique leaf.
   using LXLeaf = typename DecisionTree<L, X>::Leaf;
-  if (auto leaf = std::dynamic_pointer_cast<const LXLeaf>(f)) {
+  if (auto leaf = std::dynamic_pointer_cast<LXLeaf>(f)) {
     return NodePtr(new Leaf(Y_of_X(leaf->constant())));
   }
@ -951,11 +986,16 @@ namespace gtsam {
     return root_->equals(*other.root_);
   }

+  /****************************************************************************/
   template<typename L, typename Y>
   const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
+    if (root_ == nullptr)
+      throw std::invalid_argument(
+          "DecisionTree::operator() called on empty tree");
     return root_->operator ()(x);
   }

+  /****************************************************************************/
   template<typename L, typename Y>
   DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
     // It is unclear what should happen if tree is empty:
@ -966,6 +1006,7 @@ namespace gtsam {
     return DecisionTree(root_->apply(op));
   }

+  /****************************************************************************/
   /// Apply unary operator with assignment
   template <typename L, typename Y>
   DecisionTree<L, Y> DecisionTree<L, Y>::apply(
@ -1049,6 +1090,18 @@ namespace gtsam {
     return ss.str();
   }

   /******************************************************************************/
+  template <typename L, typename Y>
+  template <typename A, typename B>
+  std::pair<DecisionTree<L, A>, DecisionTree<L, B>> DecisionTree<L, Y>::split(
+      std::function<std::pair<A, B>(const Y&)> AB_of_Y) const {
+    using AB = std::pair<A, B>;
+    const DecisionTree<L, AB> ab(*this, AB_of_Y);
+    const DecisionTree<L, A> a(ab, [](const AB& p) { return p.first; });
+    const DecisionTree<L, B> b(ab, [](const AB& p) { return p.second; });
+    return {a, b};
+  }
+
+  /******************************************************************************/
 } // namespace gtsam

View File

@ -85,7 +85,7 @@ namespace gtsam {
 /** ------------------------ Node base class --------------------------- */
 struct Node {
-  using Ptr = std::shared_ptr<const Node>;
+  using Ptr = std::shared_ptr<Node>;

 #ifdef DT_DEBUG_MEMORY
   static int nrNodes;
@ -156,10 +156,10 @@ namespace gtsam {
 template <typename It, typename ValueIt>
 static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);

-/** Internal helper function to create from
- * keys, cardinalities, and Y values.
- * Calls `build` which builds thetree bottom-up,
- * before we prune in a top-down fashion.
+/**
+ * Internal helper function to create a tree from keys, cardinalities, and Y
+ * values. Calls `build` which builds the tree bottom-up, before we prune in
+ * a top-down fashion.
  */
 template <typename It, typename ValueIt>
 static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
@ -228,6 +228,15 @@ namespace gtsam {
 DecisionTree(const L& label, const DecisionTree& f0,
              const DecisionTree& f1);

+/**
+ * @brief Move constructor for DecisionTree. Very efficient as it does not
+ * allocate anything, just changes leaves in place. But `other` is consumed.
+ *
+ * @param op The unary operation to apply to the moved DecisionTree.
+ * @param other The DecisionTree to move from, will be empty afterwards.
+ */
+DecisionTree(const Unary& op, DecisionTree&& other) noexcept;
+
 /**
  * @brief Convert from a different value type.
  *
@ -239,7 +248,7 @@ namespace gtsam {
 DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);

 /**
- * @brief Convert from a different value type X to value type Y, also transate
+ * @brief Convert from a different value type X to value type Y, also translate
  * labels via map from type M to L.
  *
  * @tparam M Previous label type.
@ -406,6 +415,18 @@ namespace gtsam {
              const ValueFormatter& valueFormatter,
              bool showZero = true) const;

+/**
+ * @brief Convert into two trees with value types A and B.
+ *
+ * @tparam A First new value type.
+ * @tparam B Second new value type.
+ * @param AB_of_Y Functor to convert from type Y to std::pair<A, B>.
+ * @return A pair of DecisionTrees with value types A and B respectively.
+ */
+template <typename A, typename B>
+std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split(
+    std::function<std::pair<A, B>(const Y&)> AB_of_Y) const;
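Usage sketch (illustrative only, not part of the commit; the Split unit test added in the test file below exercises the same API):

// Split an int-valued tree into an int tree and a bool tree in one pass.
DecisionTree<std::string, int> f("B", 1, 2);
auto [tripled, isEven] = f.split<int, bool>(
    [](const int& v) { return std::pair<int, bool>{3 * v, (3 * v) % 2 == 0}; });
// tripled has leaves {3, 6}; isEven has leaves {false, true}.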
 /// @name Advanced Interface
 /// @{

View File

@ -11,7 +11,7 @@
 /*
  * @file testDecisionTree.cpp
- * @brief Develop DecisionTree
+ * @brief DecisionTree unit tests
  * @author Frank Dellaert
  * @author Can Erdogan
  * @date Jan 30, 2012
@ -108,6 +108,7 @@ struct DT : public DecisionTree<string, int> {
     std::cout << s;
     Base::print("", keyFormatter, valueFormatter);
   }
+
   /// Equality method customized to int node type
   bool equals(const Base& other, double tol = 1e-9) const {
     auto compare = [](const int& v, const int& w) { return v == w; };
@ -271,6 +272,58 @@ TEST(DecisionTree, Example) {
   DOT(acnotb);
 }

+/* ************************************************************************** */
+// Test that we can create two trees out of one, using a function that returns a pair.
+TEST(DecisionTree, Split) {
+  // Create labels
+  string A("A"), B("B");
+
+  // Create a decision tree
+  DT original(A, DT(B, 1, 2), DT(B, 3, 4));
+
+  // Define a function that returns an int/bool pair
+  auto split_function = [](const int& value) -> std::pair<int, bool> {
+    return {value*3, value*3 % 2 == 0};
+  };
+
+  // Split the original tree into two new trees
+  auto [la,lb] = original.split<int,bool>(split_function);
+
+  // Check the first resulting tree
+  EXPECT_LONGS_EQUAL(3, la(Assignment<string>{{A, 0}, {B, 0}}));
+  EXPECT_LONGS_EQUAL(6, la(Assignment<string>{{A, 0}, {B, 1}}));
+  EXPECT_LONGS_EQUAL(9, la(Assignment<string>{{A, 1}, {B, 0}}));
+  EXPECT_LONGS_EQUAL(12, la(Assignment<string>{{A, 1}, {B, 1}}));
+
+  // Check the second resulting tree
+  EXPECT(!lb(Assignment<string>{{A, 0}, {B, 0}}));
+  EXPECT(lb(Assignment<string>{{A, 0}, {B, 1}}));
+  EXPECT(!lb(Assignment<string>{{A, 1}, {B, 0}}));
+  EXPECT(lb(Assignment<string>{{A, 1}, {B, 1}}));
+}
+
+/* ************************************************************************** */
+// Test that we can create a tree by modifying an rvalue.
+TEST(DecisionTree, Consume) {
+  // Create labels
+  string A("A"), B("B");
+
+  // Create a decision tree
+  DT original(A, DT(B, 1, 2), DT(B, 3, 4));
+  DT modified([](int i){return i*2;}, std::move(original));
+
+  // Check the resulting tree
+  EXPECT_LONGS_EQUAL(2, modified(Assignment<string>{{A, 0}, {B, 0}}));
+  EXPECT_LONGS_EQUAL(4, modified(Assignment<string>{{A, 0}, {B, 1}}));
+  EXPECT_LONGS_EQUAL(6, modified(Assignment<string>{{A, 1}, {B, 0}}));
+  EXPECT_LONGS_EQUAL(8, modified(Assignment<string>{{A, 1}, {B, 1}}));
+
+  // Check original was moved
+  EXPECT(original.root_ == nullptr);
+}
+
 /* ************************************************************************** */
 // test Conversion of values
 bool bool_of_int(const int& y) { return y != 0; };

View File

@ -25,12 +25,27 @@
 #include <gtsam/hybrid/HybridValues.h>
 #include <gtsam/inference/Conditional-inst.h>
 #include <gtsam/linear/GaussianBayesNet.h>
+#include <gtsam/linear/GaussianConditional.h>
 #include <gtsam/linear/GaussianFactorGraph.h>
 #include <gtsam/linear/JacobianFactor.h>

 #include <cstddef>
+#include <memory>

 namespace gtsam {

+/* *******************************************************************************/
+GaussianConditional::shared_ptr checkConditional(
+    const GaussianFactor::shared_ptr &factor) {
+  if (auto conditional =
+          std::dynamic_pointer_cast<GaussianConditional>(factor)) {
+    return conditional;
+  } else {
+    throw std::logic_error(
+        "A HybridGaussianConditional unexpectedly contained a non-conditional");
+  }
+}
+
 /* *******************************************************************************/
 /**
  * @brief Helper struct for constructing HybridGaussianConditional objects
@ -38,15 +53,13 @@ namespace gtsam {
  * This struct contains the following fields:
  * - nrFrontals: Optional size_t for number of frontal variables
  * - pairs: FactorValuePairs for storing conditionals with their negLogConstant
- * - conditionals: Conditionals for storing conditionals. TODO(frank): kill!
  * - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
  *   constructor
  */
 struct HybridGaussianConditional::Helper {
-  std::optional<size_t> nrFrontals;
   FactorValuePairs pairs;
-  Conditionals conditionals;
-  double minNegLogConstant;
+  std::optional<size_t> nrFrontals = {};
+  double minNegLogConstant = std::numeric_limits<double>::infinity();

   using GC = GaussianConditional;
   using P = std::vector<std::pair<Vector, double>>;
@ -55,8 +68,6 @@ struct HybridGaussianConditional::Helper {
   template <typename... Args>
   explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
     nrFrontals = 1;
-    minNegLogConstant = std::numeric_limits<double>::infinity();

     std::vector<GaussianFactorValuePair> fvs;
     std::vector<GC::shared_ptr> gcs;
     fvs.reserve(p.size());
@ -70,14 +81,11 @@ struct HybridGaussianConditional::Helper {
       gcs.push_back(gaussianConditional);
     }

-    conditionals = Conditionals({mode}, gcs);
     pairs = FactorValuePairs({mode}, fvs);
   }

   /// Construct from tree of GaussianConditionals.
-  explicit Helper(const Conditionals &conditionals)
-      : conditionals(conditionals),
-        minNegLogConstant(std::numeric_limits<double>::infinity()) {
+  explicit Helper(const Conditionals &conditionals) {
     auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
       if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
       if (!nrFrontals) nrFrontals = gc->nrFrontals();
@ -92,21 +100,36 @@ struct HybridGaussianConditional::Helper {
"Provided conditionals do not contain any frontal variables."); "Provided conditionals do not contain any frontal variables.");
} }
} }
/// Construct from tree of factor/scalar pairs.
explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
auto func = [this](const GaussianFactorValuePair &pair) {
if (!pair.first) return;
auto gc = checkConditional(pair.first);
if (!nrFrontals) nrFrontals = gc->nrFrontals();
minNegLogConstant = std::min(minNegLogConstant, pair.second);
};
pairs.visit(func);
if (!nrFrontals.has_value()) {
throw std::runtime_error(
"HybridGaussianConditional: need at least one frontal variable. "
"Provided conditionals do not contain any frontal variables.");
}
}
}; };
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, const Helper &helper) const DiscreteKeys &discreteParents, Helper &&helper)
: BaseFactor(discreteParents, : BaseFactor(discreteParents,
FactorValuePairs(helper.pairs, FactorValuePairs(
[&](const GaussianFactorValuePair & [&](const GaussianFactorValuePair
pair) { // subtract minNegLogConstant &pair) { // subtract minNegLogConstant
return GaussianFactorValuePair{ return GaussianFactorValuePair{
pair.first, pair.first, pair.second - helper.minNegLogConstant};
pair.second - helper.minNegLogConstant}; },
})), std::move(helper.pairs))),
BaseConditional(*helper.nrFrontals), BaseConditional(*helper.nrFrontals),
conditionals_(helper.conditionals),
negLogConstant_(helper.minNegLogConstant) {} negLogConstant_(helper.minNegLogConstant) {}
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
@ -142,17 +165,23 @@ HybridGaussianConditional::HybridGaussianConditional(
     const HybridGaussianConditional::Conditionals &conditionals)
     : HybridGaussianConditional(discreteParents, Helper(conditionals)) {}

+HybridGaussianConditional::HybridGaussianConditional(
+    const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
+    : HybridGaussianConditional(discreteParents, Helper(pairs)) {}
+
 /* *******************************************************************************/
-const HybridGaussianConditional::Conditionals &
+const HybridGaussianConditional::Conditionals
 HybridGaussianConditional::conditionals() const {
-  return conditionals_;
+  return Conditionals(factors(), [](auto &&pair) {
+    return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
+  });
 }

 /* *******************************************************************************/
 size_t HybridGaussianConditional::nrComponents() const {
   size_t total = 0;
-  conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
-    if (node) total += 1;
+  factors().visit([&total](auto &&node) {
+    if (node.first) total += 1;
   });
   return total;
 }
@ -160,14 +189,11 @@ size_t HybridGaussianConditional::nrComponents() const {
 /* *******************************************************************************/
 GaussianConditional::shared_ptr HybridGaussianConditional::choose(
     const DiscreteValues &discreteValues) const {
-  auto &ptr = conditionals_(discreteValues);
-  if (!ptr) return nullptr;
-  auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
-  if (conditional)
+  auto &[factor, _] = factors()(discreteValues);
+  if (!factor) return nullptr;
+
+  auto conditional = checkConditional(factor);
   return conditional;
-  else
-    throw std::logic_error(
-        "A HybridGaussianConditional unexpectedly contained a non-conditional");
 }

 /* *******************************************************************************/
@ -176,18 +202,16 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
   const This *e = dynamic_cast<const This *>(&lf);
   if (e == nullptr) return false;

-  // This will return false if either conditionals_ is empty or e->conditionals_
-  // is empty, but not if both are empty or both are not empty:
-  if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
-
-  // Check the base and the factors:
-  return BaseFactor::equals(*e, tol) &&
-         conditionals_.equals(e->conditionals_,
-                              [tol](const GaussianConditional::shared_ptr &f1,
-                                    const GaussianConditional::shared_ptr &f2) {
-                                return (!f1 && !f2) ||
-                                       (f1 && f2 && f1->equals(*f2, tol));
-                              });
+  // Factor existence and scalar values are checked in BaseFactor::equals.
+  // Here we additionally check that the factors *are* conditionals
+  // and are equal.
+  auto compareFunc = [tol](const GaussianFactorValuePair &pair1,
+                           const GaussianFactorValuePair &pair2) {
+    auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
+         c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
+    return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
+  };
+  return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
 }

 /* *******************************************************************************/
@ -202,11 +226,12 @@ void HybridGaussianConditional::print(const std::string &s,
   std::cout << std::endl
             << " logNormalizationConstant: " << -negLogConstant() << std::endl
             << std::endl;
-  conditionals_.print(
+  factors().print(
       "", [&](Key k) { return formatter(k); },
-      [&](const GaussianConditional::shared_ptr &gf) -> std::string {
+      [&](const GaussianFactorValuePair &pair) -> std::string {
         RedirectCout rd;
-        if (gf && !gf->empty()) {
+        if (auto gf =
+                std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
           gf->print("", formatter);
           return rd.str();
         } else {
@ -254,12 +279,16 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
   const DiscreteKeys discreteParentKeys = discreteKeys();
   const KeyVector continuousParentKeys = continuousParents();
   const HybridGaussianFactor::FactorValuePairs likelihoods(
-      conditionals_,
-      [&](const GaussianConditional::shared_ptr &conditional)
-          -> GaussianFactorValuePair {
+      factors(),
+      [&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
+        if (auto conditional =
+                std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
           const auto likelihood_m = conditional->likelihood(given);
-          const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
-          return {likelihood_m, Cgm_Kgcm};
+          // pair.second == conditional->negLogConstant() - negLogConstant_
+          return {likelihood_m, pair.second};
+        } else {
+          return {nullptr, std::numeric_limits<double>::infinity()};
+        }
       });
   return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
                                                 likelihoods);
@ -288,27 +317,32 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
   // Check the max value for every combination of our keys.
   // If the max value is 0.0, we can prune the corresponding conditional.
-  auto pruner = [&](const Assignment<Key> &choices,
-                    const GaussianConditional::shared_ptr &conditional)
-      -> GaussianConditional::shared_ptr {
-    return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
+  auto pruner =
+      [&](const Assignment<Key> &choices,
+          const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
+    if (max->evaluate(choices) == 0.0)
+      return {nullptr, std::numeric_limits<double>::infinity()};
+    else
+      return pair;
   };

-  auto pruned_conditionals = conditionals_.apply(pruner);
+  FactorValuePairs prunedConditionals = factors().apply(pruner);
   return std::make_shared<HybridGaussianConditional>(discreteKeys(),
-                                                     pruned_conditionals);
+                                                     prunedConditionals);
 }

 /* *******************************************************************************/
 double HybridGaussianConditional::logProbability(
     const HybridValues &values) const {
-  auto conditional = conditionals_(values.discrete());
+  auto [factor, _] = factors()(values.discrete());
+  auto conditional = checkConditional(factor);
   return conditional->logProbability(values.continuous());
 }

 /* *******************************************************************************/
 double HybridGaussianConditional::evaluate(const HybridValues &values) const {
-  auto conditional = conditionals_(values.discrete());
+  auto [factor, _] = factors()(values.discrete());
+  auto conditional = checkConditional(factor);
   return conditional->evaluate(values.continuous());
 }

View File

@ -64,8 +64,6 @@ class GTSAM_EXPORT HybridGaussianConditional
   using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

  private:
-  Conditionals conditionals_;  ///< a decision tree of Gaussian conditionals.
   ///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
   ///< Take advantage of the neg-log space so everything is a minimization
   double negLogConstant_;
@ -143,6 +141,19 @@ class GTSAM_EXPORT HybridGaussianConditional
   HybridGaussianConditional(const DiscreteKeys &discreteParents,
                             const Conditionals &conditionals);

+  /**
+   * @brief Construct from multiple discrete keys M and a tree of
+   * factor/scalar pairs, where the scalar is assumed to be the
+   * negative log constant for each assignment m, up to a constant.
+   *
+   * @note Will throw if factors are not actually conditionals.
+   *
+   * @param discreteParents the discrete parents. Will be placed last.
+   * @param pairs Decision tree of GaussianFactor/scalar pairs.
+   */
+  HybridGaussianConditional(const DiscreteKeys &discreteParents,
+                            const FactorValuePairs &pairs);
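Usage sketch (illustrative only, not part of the commit; the keys, means, and sigmas are made up, and the usual symbol_shorthand X/M plus an initializer-list DiscreteKeys construction are assumed):

// Two-component P(x0 | m0) built directly from (factor, scalar) pairs,
// where each scalar is that component's negative log constant.
const DiscreteKey m0(M(0), 2);
const auto c0 = GaussianConditional::sharedMeanAndStddev(X(0), Vector1(0.0), 0.5);
const auto c1 = GaussianConditional::sharedMeanAndStddev(X(0), Vector1(2.0), 3.0);
const HybridGaussianConditional::FactorValuePairs pairs(
    {m0}, std::vector<GaussianFactorValuePair>{{c0, c0->negLogConstant()},
                                               {c1, c1->negLogConstant()}});
const HybridGaussianConditional hgc({m0}, pairs);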
   /// @}
   /// @name Testable
   /// @{
@ -192,8 +203,9 @@ class GTSAM_EXPORT HybridGaussianConditional
   std::shared_ptr<HybridGaussianFactor> likelihood(
       const VectorValues &given) const;

-  /// Getter for the underlying Conditionals DecisionTree
-  const Conditionals &conditionals() const;
+  /// Get Conditionals DecisionTree (dynamic cast from factors)
+  /// @note Slow: avoid using in favor of factors(), which uses existing tree.
+  const Conditionals conditionals() const;
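A hedged sketch of the pattern that note recommends (illustrative only, not part of the commit), given some FactorValuePairs tree named pairs:

// Visit the factor/scalar leaves and downcast only where a conditional is
// needed, instead of materializing a whole Conditionals tree via conditionals().
pairs.visit([](const GaussianFactorValuePair& pair) {
  if (auto gc = std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
    gc->print("component");  // use the GaussianConditional component here
  }
});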
   /**
    * @brief Compute the logProbability of this hybrid Gaussian conditional.
@ -229,7 +241,7 @@ class GTSAM_EXPORT HybridGaussianConditional
   /// Private constructor that uses helper struct above.
   HybridGaussianConditional(const DiscreteKeys &discreteParents,
-                            const Helper &helper);
+                            Helper &&helper);

   /// Check whether `given` has values for all frontal keys.
   bool allFrontalsGiven(const VectorValues &given) const;
@ -241,7 +253,6 @@ class GTSAM_EXPORT HybridGaussianConditional
   void serialize(Archive &ar, const unsigned int /*version*/) {
     ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
     ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
-    ar &BOOST_SERIALIZATION_NVP(conditionals_);
   }
 #endif
 };

View File

@ -20,6 +20,7 @@
 #include <gtsam/base/utilities.h>
 #include <gtsam/discrete/Assignment.h>
+#include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteEliminationTree.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
 #include <gtsam/discrete/DiscreteJunctionTree.h>
@ -48,8 +49,6 @@
 #include <utility>
 #include <vector>

-#include "gtsam/discrete/DecisionTreeFactor.h"

 namespace gtsam {

 /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@ -57,10 +56,20 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
 using std::dynamic_pointer_cast;
 using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;

-using Result =
-    std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
-using ResultValuePair = std::pair<Result, double>;
-using ResultTree = DecisionTree<Key, ResultValuePair>;
+/// Result from elimination.
+struct Result {
+  GaussianConditional::shared_ptr conditional;
+  double negLogK;
+  GaussianFactor::shared_ptr factor;
+  double scalar;
+
+  bool operator==(const Result &other) const {
+    return conditional == other.conditional && negLogK == other.negLogK &&
+           factor == other.factor && scalar == other.scalar;
+  }
+};
+using ResultTree = DecisionTree<Key, Result>;
 static const VectorValues kEmpty;
@ -294,17 +303,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
 static std::shared_ptr<Factor> createDiscreteFactor(
     const ResultTree &eliminationResults,
     const DiscreteKeys &discreteSeparator) {
-  auto calculateError = [&](const auto &pair) -> double {
-    const auto &[conditional, factor] = pair.first;
-    const double scalar = pair.second;
-    if (conditional && factor) {
+  auto calculateError = [&](const Result &result) -> double {
+    if (result.conditional && result.factor) {
       // `error` has the following contributions:
       // - the scalar is the sum of all mode-dependent constants
       // - factor->error(kempty) is the error remaining after elimination
       // - negLogK is what is given to the conditional to normalize
-      const double negLogK = conditional->negLogConstant();
-      return scalar + factor->error(kEmpty) - negLogK;
-    } else if (!conditional && !factor) {
+      return result.scalar + result.factor->error(kEmpty) - result.negLogK;
+    } else if (!result.conditional && !result.factor) {
       // If the factor has been pruned, return infinite error
       return std::numeric_limits<double>::infinity();
     } else {
@ -323,13 +329,10 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
     const ResultTree &eliminationResults,
     const DiscreteKeys &discreteSeparator) {
   // Correct for the normalization constant used up by the conditional
-  auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
-    const auto &[conditional, factor] = pair.first;
-    const double scalar = pair.second;
-    if (conditional && factor) {
-      const double negLogK = conditional->negLogConstant();
-      return {factor, scalar - negLogK};
-    } else if (!conditional && !factor) {
+  auto correct = [&](const Result &result) -> GaussianFactorValuePair {
+    if (result.conditional && result.factor) {
+      return {result.factor, result.scalar - result.negLogK};
+    } else if (!result.conditional && !result.factor) {
       return {nullptr, std::numeric_limits<double>::infinity()};
     } else {
       throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
@ -363,34 +366,34 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
   // any difference in noise models used.
   HybridGaussianProductFactor productFactor = collectProductFactor();

-  // Convert factor graphs with a nullptr to an empty factor graph.
-  // This is done after assembly since it is non-trivial to keep track of which
-  // FG has a nullptr as we're looping over the factors.
-  auto prunedProductFactor = productFactor.removeEmpty();
+  // Check if a factor is null
+  auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };

   // This is the elimination method on the leaf nodes
   bool someContinuousLeft = false;
-  auto eliminate = [&](const std::pair<GaussianFactorGraph, double> &pair)
-      -> std::pair<Result, double> {
+  auto eliminate =
+      [&](const std::pair<GaussianFactorGraph, double> &pair) -> Result {
     const auto &[graph, scalar] = pair;

-    if (graph.empty()) {
-      return {{nullptr, nullptr}, 0.0};
+    // If any product contains a pruned factor, prune it here. Done here as it
+    // is non-trivial to do within collectProductFactor.
+    if (graph.empty() || std::any_of(graph.begin(), graph.end(), isNull)) {
+      return {nullptr, 0.0, nullptr, 0.0};
     }

     // Expensive elimination of product factor.
-    auto result =
+    auto [conditional, factor] =
         EliminatePreferCholesky(graph, keys);  /// <<<<<< MOST COMPUTE IS HERE

     // Record whether there any continuous variables left
-    someContinuousLeft |= !result.second->empty();
+    someContinuousLeft |= !factor->empty();

     // We pass on the scalar unmodified.
-    return {result, scalar};
+    return {conditional, conditional->negLogConstant(), factor, scalar};
   };

   // Perform elimination!
-  ResultTree eliminationResults(prunedProductFactor, eliminate);
+  const ResultTree eliminationResults(productFactor, eliminate);

   // If there are no more continuous parents we create a DiscreteFactor with the
   // error for each discrete choice. Otherwise, create a HybridGaussianFactor
@ -400,12 +403,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
           ? createHybridGaussianFactor(eliminationResults, discreteSeparator)
           : createDiscreteFactor(eliminationResults, discreteSeparator);

-  // Create the HybridGaussianConditional from the conditionals
-  HybridGaussianConditional::Conditionals conditionals(
-      eliminationResults,
-      [](const ResultValuePair &pair) { return pair.first.first; });
-  auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
-      discreteSeparator, conditionals);
+  // Create the HybridGaussianConditional without re-calculating constants:
+  HybridGaussianConditional::FactorValuePairs pairs(
+      eliminationResults, [](const Result &result) -> GaussianFactorValuePair {
+        return {result.conditional, result.negLogK};
+      });
+  auto hybridGaussian =
+      std::make_shared<HybridGaussianConditional>(discreteSeparator, pairs);

   return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
 }