commit
366b51432a
|
|
@ -29,6 +29,7 @@
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -286,6 +287,10 @@ namespace gtsam {
|
||||||
return branches_;
|
return branches_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<NodePtr>& branches() {
|
||||||
|
return branches_;
|
||||||
|
}
|
||||||
|
|
||||||
/** add a branch: TODO merge into constructor */
|
/** add a branch: TODO merge into constructor */
|
||||||
void push_back(NodePtr&& node) {
|
void push_back(NodePtr&& node) {
|
||||||
// allSame_ is restricted to leaf nodes in a decision tree
|
// allSame_ is restricted to leaf nodes in a decision tree
|
||||||
|
|
@ -483,7 +488,7 @@ namespace gtsam {
|
||||||
// DecisionTree
|
// DecisionTree
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree() {}
|
DecisionTree<L, Y>::DecisionTree() : root_(nullptr) {}
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
|
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
|
||||||
|
|
@ -554,6 +559,36 @@ namespace gtsam {
|
||||||
root_ = compose(functions.begin(), functions.end(), label);
|
root_ = compose(functions.begin(), functions.end(), label);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
DecisionTree<L, Y>::DecisionTree(const Unary& op,
|
||||||
|
DecisionTree&& other) noexcept
|
||||||
|
: root_(std::move(other.root_)) {
|
||||||
|
// Apply the unary operation directly to each leaf in the tree
|
||||||
|
if (root_) {
|
||||||
|
// Define a helper function to traverse and apply the operation
|
||||||
|
struct ApplyUnary {
|
||||||
|
const Unary& op;
|
||||||
|
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
|
if (auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
|
||||||
|
// Apply the unary operation to the leaf's constant value
|
||||||
|
leaf->constant_ = op(leaf->constant_);
|
||||||
|
} else if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
||||||
|
// Recurse into the choice branches
|
||||||
|
for (NodePtr& branch : choice->branches()) {
|
||||||
|
(*this)(branch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ApplyUnary applyUnary{op};
|
||||||
|
applyUnary(root_);
|
||||||
|
}
|
||||||
|
// Reset the other tree's root to nullptr to avoid dangling references
|
||||||
|
other.root_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename X, typename Func>
|
template <typename X, typename Func>
|
||||||
|
|
@ -694,7 +729,7 @@ namespace gtsam {
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||||
It begin, It end, ValueIt beginY, ValueIt endY) {
|
It begin, It end, ValueIt beginY, ValueIt endY) {
|
||||||
auto node = build(begin, end, beginY, endY);
|
auto node = build(begin, end, beginY, endY);
|
||||||
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
|
if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
||||||
return Choice::Unique(choice);
|
return Choice::Unique(choice);
|
||||||
} else {
|
} else {
|
||||||
return node;
|
return node;
|
||||||
|
|
@ -710,7 +745,7 @@ namespace gtsam {
|
||||||
|
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const LXLeaf>(f)) {
|
if (auto leaf = std::dynamic_pointer_cast<LXLeaf>(f)) {
|
||||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -951,11 +986,16 @@ namespace gtsam {
|
||||||
return root_->equals(*other.root_);
|
return root_->equals(*other.root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
|
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
|
||||||
|
if (root_ == nullptr)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DecisionTree::operator() called on empty tree");
|
||||||
return root_->operator ()(x);
|
return root_->operator ()(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
||||||
// It is unclear what should happen if tree is empty:
|
// It is unclear what should happen if tree is empty:
|
||||||
|
|
@ -966,6 +1006,7 @@ namespace gtsam {
|
||||||
return DecisionTree(root_->apply(op));
|
return DecisionTree(root_->apply(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
/// Apply unary operator with assignment
|
/// Apply unary operator with assignment
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
|
||||||
|
|
@ -1049,6 +1090,18 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename A, typename B>
|
||||||
|
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> DecisionTree<L, Y>::split(
|
||||||
|
std::function<std::pair<A, B>(const Y&)> AB_of_Y) const {
|
||||||
|
using AB = std::pair<A, B>;
|
||||||
|
const DecisionTree<L, AB> ab(*this, AB_of_Y);
|
||||||
|
const DecisionTree<L, A> a(ab, [](const AB& p) { return p.first; });
|
||||||
|
const DecisionTree<L, B> b(ab, [](const AB& p) { return p.second; });
|
||||||
|
return {a, b};
|
||||||
|
}
|
||||||
|
|
||||||
/******************************************************************************/
|
/******************************************************************************/
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
struct Node {
|
struct Node {
|
||||||
using Ptr = std::shared_ptr<const Node>;
|
using Ptr = std::shared_ptr<Node>;
|
||||||
|
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
static int nrNodes;
|
static int nrNodes;
|
||||||
|
|
@ -156,10 +156,10 @@ namespace gtsam {
|
||||||
template <typename It, typename ValueIt>
|
template <typename It, typename ValueIt>
|
||||||
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);
|
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);
|
||||||
|
|
||||||
/** Internal helper function to create from
|
/**
|
||||||
* keys, cardinalities, and Y values.
|
* Internal helper function to create a tree from keys, cardinalities, and Y
|
||||||
* Calls `build` which builds thetree bottom-up,
|
* values. Calls `build` which builds the tree bottom-up, before we prune in
|
||||||
* before we prune in a top-down fashion.
|
* a top-down fashion.
|
||||||
*/
|
*/
|
||||||
template <typename It, typename ValueIt>
|
template <typename It, typename ValueIt>
|
||||||
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
|
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
|
||||||
|
|
@ -228,6 +228,15 @@ namespace gtsam {
|
||||||
DecisionTree(const L& label, const DecisionTree& f0,
|
DecisionTree(const L& label, const DecisionTree& f0,
|
||||||
const DecisionTree& f1);
|
const DecisionTree& f1);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Move constructor for DecisionTree. Very efficient as does not
|
||||||
|
* allocate anything, just changes in-place. But `other` is consumed.
|
||||||
|
*
|
||||||
|
* @param op The unary operation to apply to the moved DecisionTree.
|
||||||
|
* @param other The DecisionTree to move from, will be empty afterwards.
|
||||||
|
*/
|
||||||
|
DecisionTree(const Unary& op, DecisionTree&& other) noexcept;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a different value type.
|
* @brief Convert from a different value type.
|
||||||
*
|
*
|
||||||
|
|
@ -239,7 +248,7 @@ namespace gtsam {
|
||||||
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a different value type X to value type Y, also transate
|
* @brief Convert from a different value type X to value type Y, also translate
|
||||||
* labels via map from type M to L.
|
* labels via map from type M to L.
|
||||||
*
|
*
|
||||||
* @tparam M Previous label type.
|
* @tparam M Previous label type.
|
||||||
|
|
@ -406,6 +415,18 @@ namespace gtsam {
|
||||||
const ValueFormatter& valueFormatter,
|
const ValueFormatter& valueFormatter,
|
||||||
bool showZero = true) const;
|
bool showZero = true) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Convert into two trees with value types A and B.
|
||||||
|
*
|
||||||
|
* @tparam A First new value type.
|
||||||
|
* @tparam B Second new value type.
|
||||||
|
* @param AB_of_Y Functor to convert from type X to std::pair<A, B>.
|
||||||
|
* @return A pair of DecisionTrees with value types A and B respectively.
|
||||||
|
*/
|
||||||
|
template <typename A, typename B>
|
||||||
|
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split(
|
||||||
|
std::function<std::pair<A, B>(const Y&)> AB_of_Y) const;
|
||||||
|
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* @file testDecisionTree.cpp
|
* @file testDecisionTree.cpp
|
||||||
* @brief Develop DecisionTree
|
* @brief DecisionTree unit tests
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
* @author Can Erdogan
|
* @author Can Erdogan
|
||||||
* @date Jan 30, 2012
|
* @date Jan 30, 2012
|
||||||
|
|
@ -108,6 +108,7 @@ struct DT : public DecisionTree<string, int> {
|
||||||
std::cout << s;
|
std::cout << s;
|
||||||
Base::print("", keyFormatter, valueFormatter);
|
Base::print("", keyFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Equality method customized to int node type
|
/// Equality method customized to int node type
|
||||||
bool equals(const Base& other, double tol = 1e-9) const {
|
bool equals(const Base& other, double tol = 1e-9) const {
|
||||||
auto compare = [](const int& v, const int& w) { return v == w; };
|
auto compare = [](const int& v, const int& w) { return v == w; };
|
||||||
|
|
@ -271,6 +272,58 @@ TEST(DecisionTree, Example) {
|
||||||
DOT(acnotb);
|
DOT(acnotb);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test that we can create two trees out of one, using a function that returns a pair.
|
||||||
|
TEST(DecisionTree, Split) {
|
||||||
|
// Create labels
|
||||||
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
// Create a decision tree
|
||||||
|
DT original(A, DT(B, 1, 2), DT(B, 3, 4));
|
||||||
|
|
||||||
|
// Define a function that returns an int/bool pair
|
||||||
|
auto split_function = [](const int& value) -> std::pair<int, bool> {
|
||||||
|
return {value*3, value*3 % 2 == 0};
|
||||||
|
};
|
||||||
|
|
||||||
|
// Split the original tree into two new trees
|
||||||
|
auto [la,lb] = original.split<int,bool>(split_function);
|
||||||
|
|
||||||
|
// Check the first resulting tree
|
||||||
|
EXPECT_LONGS_EQUAL(3, la(Assignment<string>{{A, 0}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(6, la(Assignment<string>{{A, 0}, {B, 1}}));
|
||||||
|
EXPECT_LONGS_EQUAL(9, la(Assignment<string>{{A, 1}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(12, la(Assignment<string>{{A, 1}, {B, 1}}));
|
||||||
|
|
||||||
|
// Check the second resulting tree
|
||||||
|
EXPECT(!lb(Assignment<string>{{A, 0}, {B, 0}}));
|
||||||
|
EXPECT(lb(Assignment<string>{{A, 0}, {B, 1}}));
|
||||||
|
EXPECT(!lb(Assignment<string>{{A, 1}, {B, 0}}));
|
||||||
|
EXPECT(lb(Assignment<string>{{A, 1}, {B, 1}}));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test that we can create a tree by modifying an rvalue.
|
||||||
|
TEST(DecisionTree, Consume) {
|
||||||
|
// Create labels
|
||||||
|
string A("A"), B("B");
|
||||||
|
|
||||||
|
// Create a decision tree
|
||||||
|
DT original(A, DT(B, 1, 2), DT(B, 3, 4));
|
||||||
|
|
||||||
|
DT modified([](int i){return i*2;}, std::move(original));
|
||||||
|
|
||||||
|
// Check the first resulting tree
|
||||||
|
EXPECT_LONGS_EQUAL(2, modified(Assignment<string>{{A, 0}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(4, modified(Assignment<string>{{A, 0}, {B, 1}}));
|
||||||
|
EXPECT_LONGS_EQUAL(6, modified(Assignment<string>{{A, 1}, {B, 0}}));
|
||||||
|
EXPECT_LONGS_EQUAL(8, modified(Assignment<string>{{A, 1}, {B, 1}}));
|
||||||
|
|
||||||
|
// Check original was moved
|
||||||
|
EXPECT(original.root_ == nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Conversion of values
|
// test Conversion of values
|
||||||
bool bool_of_int(const int& y) { return y != 0; };
|
bool bool_of_int(const int& y) { return y != 0; };
|
||||||
|
|
|
||||||
|
|
@ -25,12 +25,27 @@
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/linear/GaussianBayesNet.h>
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
GaussianConditional::shared_ptr checkConditional(
|
||||||
|
const GaussianFactor::shared_ptr &factor) {
|
||||||
|
if (auto conditional =
|
||||||
|
std::dynamic_pointer_cast<GaussianConditional>(factor)) {
|
||||||
|
return conditional;
|
||||||
|
} else {
|
||||||
|
throw std::logic_error(
|
||||||
|
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
/**
|
/**
|
||||||
* @brief Helper struct for constructing HybridGaussianConditional objects
|
* @brief Helper struct for constructing HybridGaussianConditional objects
|
||||||
|
|
@ -38,15 +53,13 @@ namespace gtsam {
|
||||||
* This struct contains the following fields:
|
* This struct contains the following fields:
|
||||||
* - nrFrontals: Optional size_t for number of frontal variables
|
* - nrFrontals: Optional size_t for number of frontal variables
|
||||||
* - pairs: FactorValuePairs for storing conditionals with their negLogConstant
|
* - pairs: FactorValuePairs for storing conditionals with their negLogConstant
|
||||||
* - conditionals: Conditionals for storing conditionals. TODO(frank): kill!
|
|
||||||
* - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
|
* - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
|
||||||
* constructor
|
* constructor
|
||||||
*/
|
*/
|
||||||
struct HybridGaussianConditional::Helper {
|
struct HybridGaussianConditional::Helper {
|
||||||
std::optional<size_t> nrFrontals;
|
|
||||||
FactorValuePairs pairs;
|
FactorValuePairs pairs;
|
||||||
Conditionals conditionals;
|
std::optional<size_t> nrFrontals = {};
|
||||||
double minNegLogConstant;
|
double minNegLogConstant = std::numeric_limits<double>::infinity();
|
||||||
|
|
||||||
using GC = GaussianConditional;
|
using GC = GaussianConditional;
|
||||||
using P = std::vector<std::pair<Vector, double>>;
|
using P = std::vector<std::pair<Vector, double>>;
|
||||||
|
|
@ -55,8 +68,6 @@ struct HybridGaussianConditional::Helper {
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
|
explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
|
||||||
nrFrontals = 1;
|
nrFrontals = 1;
|
||||||
minNegLogConstant = std::numeric_limits<double>::infinity();
|
|
||||||
|
|
||||||
std::vector<GaussianFactorValuePair> fvs;
|
std::vector<GaussianFactorValuePair> fvs;
|
||||||
std::vector<GC::shared_ptr> gcs;
|
std::vector<GC::shared_ptr> gcs;
|
||||||
fvs.reserve(p.size());
|
fvs.reserve(p.size());
|
||||||
|
|
@ -70,14 +81,11 @@ struct HybridGaussianConditional::Helper {
|
||||||
gcs.push_back(gaussianConditional);
|
gcs.push_back(gaussianConditional);
|
||||||
}
|
}
|
||||||
|
|
||||||
conditionals = Conditionals({mode}, gcs);
|
|
||||||
pairs = FactorValuePairs({mode}, fvs);
|
pairs = FactorValuePairs({mode}, fvs);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct from tree of GaussianConditionals.
|
/// Construct from tree of GaussianConditionals.
|
||||||
explicit Helper(const Conditionals &conditionals)
|
explicit Helper(const Conditionals &conditionals) {
|
||||||
: conditionals(conditionals),
|
|
||||||
minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
|
||||||
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
||||||
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||||
|
|
@ -92,21 +100,36 @@ struct HybridGaussianConditional::Helper {
|
||||||
"Provided conditionals do not contain any frontal variables.");
|
"Provided conditionals do not contain any frontal variables.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Construct from tree of factor/scalar pairs.
|
||||||
|
explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
|
||||||
|
auto func = [this](const GaussianFactorValuePair &pair) {
|
||||||
|
if (!pair.first) return;
|
||||||
|
auto gc = checkConditional(pair.first);
|
||||||
|
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||||
|
minNegLogConstant = std::min(minNegLogConstant, pair.second);
|
||||||
|
};
|
||||||
|
pairs.visit(func);
|
||||||
|
if (!nrFrontals.has_value()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridGaussianConditional: need at least one frontal variable. "
|
||||||
|
"Provided conditionals do not contain any frontal variables.");
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const DiscreteKeys &discreteParents, const Helper &helper)
|
const DiscreteKeys &discreteParents, Helper &&helper)
|
||||||
: BaseFactor(discreteParents,
|
: BaseFactor(discreteParents,
|
||||||
FactorValuePairs(helper.pairs,
|
FactorValuePairs(
|
||||||
[&](const GaussianFactorValuePair &
|
[&](const GaussianFactorValuePair
|
||||||
pair) { // subtract minNegLogConstant
|
&pair) { // subtract minNegLogConstant
|
||||||
return GaussianFactorValuePair{
|
return GaussianFactorValuePair{
|
||||||
pair.first,
|
pair.first, pair.second - helper.minNegLogConstant};
|
||||||
pair.second - helper.minNegLogConstant};
|
},
|
||||||
})),
|
std::move(helper.pairs))),
|
||||||
BaseConditional(*helper.nrFrontals),
|
BaseConditional(*helper.nrFrontals),
|
||||||
conditionals_(helper.conditionals),
|
|
||||||
negLogConstant_(helper.minNegLogConstant) {}
|
negLogConstant_(helper.minNegLogConstant) {}
|
||||||
|
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
|
|
@ -142,17 +165,23 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const HybridGaussianConditional::Conditionals &conditionals)
|
const HybridGaussianConditional::Conditionals &conditionals)
|
||||||
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
||||||
|
|
||||||
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
|
const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
|
||||||
|
: HybridGaussianConditional(discreteParents, Helper(pairs)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
const HybridGaussianConditional::Conditionals &
|
const HybridGaussianConditional::Conditionals
|
||||||
HybridGaussianConditional::conditionals() const {
|
HybridGaussianConditional::conditionals() const {
|
||||||
return conditionals_;
|
return Conditionals(factors(), [](auto &&pair) {
|
||||||
|
return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
size_t HybridGaussianConditional::nrComponents() const {
|
size_t HybridGaussianConditional::nrComponents() const {
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
|
factors().visit([&total](auto &&node) {
|
||||||
if (node) total += 1;
|
if (node.first) total += 1;
|
||||||
});
|
});
|
||||||
return total;
|
return total;
|
||||||
}
|
}
|
||||||
|
|
@ -160,14 +189,11 @@ size_t HybridGaussianConditional::nrComponents() const {
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
||||||
const DiscreteValues &discreteValues) const {
|
const DiscreteValues &discreteValues) const {
|
||||||
auto &ptr = conditionals_(discreteValues);
|
auto &[factor, _] = factors()(discreteValues);
|
||||||
if (!ptr) return nullptr;
|
if (!factor) return nullptr;
|
||||||
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
|
|
||||||
if (conditional)
|
auto conditional = checkConditional(factor);
|
||||||
return conditional;
|
return conditional;
|
||||||
else
|
|
||||||
throw std::logic_error(
|
|
||||||
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
@ -176,18 +202,16 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
||||||
const This *e = dynamic_cast<const This *>(&lf);
|
const This *e = dynamic_cast<const This *>(&lf);
|
||||||
if (e == nullptr) return false;
|
if (e == nullptr) return false;
|
||||||
|
|
||||||
// This will return false if either conditionals_ is empty or e->conditionals_
|
// Factors existence and scalar values are checked in BaseFactor::equals.
|
||||||
// is empty, but not if both are empty or both are not empty:
|
// Here we check additionally that the factors *are* conditionals
|
||||||
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
|
// and are equal.
|
||||||
|
auto compareFunc = [tol](const GaussianFactorValuePair &pair1,
|
||||||
// Check the base and the factors:
|
const GaussianFactorValuePair &pair2) {
|
||||||
return BaseFactor::equals(*e, tol) &&
|
auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
|
||||||
conditionals_.equals(e->conditionals_,
|
c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
|
||||||
[tol](const GaussianConditional::shared_ptr &f1,
|
return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
|
||||||
const GaussianConditional::shared_ptr &f2) {
|
};
|
||||||
return (!f1 && !f2) ||
|
return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
|
||||||
(f1 && f2 && f1->equals(*f2, tol));
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
@ -202,11 +226,12 @@ void HybridGaussianConditional::print(const std::string &s,
|
||||||
std::cout << std::endl
|
std::cout << std::endl
|
||||||
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
conditionals_.print(
|
factors().print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"", [&](Key k) { return formatter(k); },
|
||||||
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
|
[&](const GaussianFactorValuePair &pair) -> std::string {
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
if (gf && !gf->empty()) {
|
if (auto gf =
|
||||||
|
std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
|
||||||
gf->print("", formatter);
|
gf->print("", formatter);
|
||||||
return rd.str();
|
return rd.str();
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -254,12 +279,16 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
||||||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||||
const KeyVector continuousParentKeys = continuousParents();
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
||||||
conditionals_,
|
factors(),
|
||||||
[&](const GaussianConditional::shared_ptr &conditional)
|
[&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
|
||||||
-> GaussianFactorValuePair {
|
if (auto conditional =
|
||||||
|
std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
|
||||||
const auto likelihood_m = conditional->likelihood(given);
|
const auto likelihood_m = conditional->likelihood(given);
|
||||||
const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
|
// pair.second == conditional->negLogConstant() - negLogConstant_
|
||||||
return {likelihood_m, Cgm_Kgcm};
|
return {likelihood_m, pair.second};
|
||||||
|
} else {
|
||||||
|
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
|
}
|
||||||
});
|
});
|
||||||
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
|
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
|
||||||
likelihoods);
|
likelihoods);
|
||||||
|
|
@ -288,27 +317,32 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
|
|
||||||
// Check the max value for every combination of our keys.
|
// Check the max value for every combination of our keys.
|
||||||
// If the max value is 0.0, we can prune the corresponding conditional.
|
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||||
auto pruner = [&](const Assignment<Key> &choices,
|
auto pruner =
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
[&](const Assignment<Key> &choices,
|
||||||
-> GaussianConditional::shared_ptr {
|
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
|
||||||
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
if (max->evaluate(choices) == 0.0)
|
||||||
|
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
|
else
|
||||||
|
return pair;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
||||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||||
pruned_conditionals);
|
prunedConditionals);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianConditional::logProbability(
|
double HybridGaussianConditional::logProbability(
|
||||||
const HybridValues &values) const {
|
const HybridValues &values) const {
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto [factor, _] = factors()(values.discrete());
|
||||||
|
auto conditional = checkConditional(factor);
|
||||||
return conditional->logProbability(values.continuous());
|
return conditional->logProbability(values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
|
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto [factor, _] = factors()(values.discrete());
|
||||||
|
auto conditional = checkConditional(factor);
|
||||||
return conditional->evaluate(values.continuous());
|
return conditional->evaluate(values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -64,8 +64,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
|
||||||
|
|
||||||
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
|
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
|
||||||
///< Take advantage of the neg-log space so everything is a minimization
|
///< Take advantage of the neg-log space so everything is a minimization
|
||||||
double negLogConstant_;
|
double negLogConstant_;
|
||||||
|
|
@ -143,6 +141,19 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||||
const Conditionals &conditionals);
|
const Conditionals &conditionals);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct from multiple discrete keys M and a tree of
|
||||||
|
* factor/scalar pairs, where the scalar is assumed to be the
|
||||||
|
* the negative log constant for each assignment m, up to a constant.
|
||||||
|
*
|
||||||
|
* @note Will throw if factors are not actually conditionals.
|
||||||
|
*
|
||||||
|
* @param discreteParents the discrete parents. Will be placed last.
|
||||||
|
* @param conditionalPairs Decision tree of GaussianFactor/scalar pairs.
|
||||||
|
*/
|
||||||
|
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||||
|
const FactorValuePairs &pairs);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -192,8 +203,9 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
std::shared_ptr<HybridGaussianFactor> likelihood(
|
std::shared_ptr<HybridGaussianFactor> likelihood(
|
||||||
const VectorValues &given) const;
|
const VectorValues &given) const;
|
||||||
|
|
||||||
/// Getter for the underlying Conditionals DecisionTree
|
/// Get Conditionals DecisionTree (dynamic cast from factors)
|
||||||
const Conditionals &conditionals() const;
|
/// @note Slow: avoid using in favor of factors(), which uses existing tree.
|
||||||
|
const Conditionals conditionals() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the logProbability of this hybrid Gaussian conditional.
|
* @brief Compute the logProbability of this hybrid Gaussian conditional.
|
||||||
|
|
@ -229,7 +241,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
|
|
||||||
/// Private constructor that uses helper struct above.
|
/// Private constructor that uses helper struct above.
|
||||||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||||
const Helper &helper);
|
Helper &&helper);
|
||||||
|
|
||||||
/// Check whether `given` has values for all frontal keys.
|
/// Check whether `given` has values for all frontal keys.
|
||||||
bool allFrontalsGiven(const VectorValues &given) const;
|
bool allFrontalsGiven(const VectorValues &given) const;
|
||||||
|
|
@ -241,7 +253,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
void serialize(Archive &ar, const unsigned int /*version*/) {
|
void serialize(Archive &ar, const unsigned int /*version*/) {
|
||||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||||
ar &BOOST_SERIALIZATION_NVP(conditionals_);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include <gtsam/base/utilities.h>
|
#include <gtsam/base/utilities.h>
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
|
|
@ -48,8 +49,6 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gtsam/discrete/DecisionTreeFactor.h"
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||||
|
|
@ -57,10 +56,20 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||||
|
|
||||||
using std::dynamic_pointer_cast;
|
using std::dynamic_pointer_cast;
|
||||||
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
|
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
|
||||||
using Result =
|
|
||||||
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
|
/// Result from elimination.
|
||||||
using ResultValuePair = std::pair<Result, double>;
|
struct Result {
|
||||||
using ResultTree = DecisionTree<Key, ResultValuePair>;
|
GaussianConditional::shared_ptr conditional;
|
||||||
|
double negLogK;
|
||||||
|
GaussianFactor::shared_ptr factor;
|
||||||
|
double scalar;
|
||||||
|
|
||||||
|
bool operator==(const Result &other) const {
|
||||||
|
return conditional == other.conditional && negLogK == other.negLogK &&
|
||||||
|
factor == other.factor && scalar == other.scalar;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
using ResultTree = DecisionTree<Key, Result>;
|
||||||
|
|
||||||
static const VectorValues kEmpty;
|
static const VectorValues kEmpty;
|
||||||
|
|
||||||
|
|
@ -294,17 +303,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
static std::shared_ptr<Factor> createDiscreteFactor(
|
static std::shared_ptr<Factor> createDiscreteFactor(
|
||||||
const ResultTree &eliminationResults,
|
const ResultTree &eliminationResults,
|
||||||
const DiscreteKeys &discreteSeparator) {
|
const DiscreteKeys &discreteSeparator) {
|
||||||
auto calculateError = [&](const auto &pair) -> double {
|
auto calculateError = [&](const Result &result) -> double {
|
||||||
const auto &[conditional, factor] = pair.first;
|
if (result.conditional && result.factor) {
|
||||||
const double scalar = pair.second;
|
|
||||||
if (conditional && factor) {
|
|
||||||
// `error` has the following contributions:
|
// `error` has the following contributions:
|
||||||
// - the scalar is the sum of all mode-dependent constants
|
// - the scalar is the sum of all mode-dependent constants
|
||||||
// - factor->error(kempty) is the error remaining after elimination
|
// - factor->error(kempty) is the error remaining after elimination
|
||||||
// - negLogK is what is given to the conditional to normalize
|
// - negLogK is what is given to the conditional to normalize
|
||||||
const double negLogK = conditional->negLogConstant();
|
return result.scalar + result.factor->error(kEmpty) - result.negLogK;
|
||||||
return scalar + factor->error(kEmpty) - negLogK;
|
} else if (!result.conditional && !result.factor) {
|
||||||
} else if (!conditional && !factor) {
|
|
||||||
// If the factor has been pruned, return infinite error
|
// If the factor has been pruned, return infinite error
|
||||||
return std::numeric_limits<double>::infinity();
|
return std::numeric_limits<double>::infinity();
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -323,13 +329,10 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
||||||
const ResultTree &eliminationResults,
|
const ResultTree &eliminationResults,
|
||||||
const DiscreteKeys &discreteSeparator) {
|
const DiscreteKeys &discreteSeparator) {
|
||||||
// Correct for the normalization constant used up by the conditional
|
// Correct for the normalization constant used up by the conditional
|
||||||
auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
|
auto correct = [&](const Result &result) -> GaussianFactorValuePair {
|
||||||
const auto &[conditional, factor] = pair.first;
|
if (result.conditional && result.factor) {
|
||||||
const double scalar = pair.second;
|
return {result.factor, result.scalar - result.negLogK};
|
||||||
if (conditional && factor) {
|
} else if (!result.conditional && !result.factor) {
|
||||||
const double negLogK = conditional->negLogConstant();
|
|
||||||
return {factor, scalar - negLogK};
|
|
||||||
} else if (!conditional && !factor) {
|
|
||||||
return {nullptr, std::numeric_limits<double>::infinity()};
|
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
|
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
|
||||||
|
|
@ -363,34 +366,34 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
// any difference in noise models used.
|
// any difference in noise models used.
|
||||||
HybridGaussianProductFactor productFactor = collectProductFactor();
|
HybridGaussianProductFactor productFactor = collectProductFactor();
|
||||||
|
|
||||||
// Convert factor graphs with a nullptr to an empty factor graph.
|
// Check if a factor is null
|
||||||
// This is done after assembly since it is non-trivial to keep track of which
|
auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };
|
||||||
// FG has a nullptr as we're looping over the factors.
|
|
||||||
auto prunedProductFactor = productFactor.removeEmpty();
|
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
bool someContinuousLeft = false;
|
bool someContinuousLeft = false;
|
||||||
auto eliminate = [&](const std::pair<GaussianFactorGraph, double> &pair)
|
auto eliminate =
|
||||||
-> std::pair<Result, double> {
|
[&](const std::pair<GaussianFactorGraph, double> &pair) -> Result {
|
||||||
const auto &[graph, scalar] = pair;
|
const auto &[graph, scalar] = pair;
|
||||||
|
|
||||||
if (graph.empty()) {
|
// If any product contains a pruned factor, prune it here. Done here as it's
|
||||||
return {{nullptr, nullptr}, 0.0};
|
// non non-trivial to do within collectProductFactor.
|
||||||
|
if (graph.empty() || std::any_of(graph.begin(), graph.end(), isNull)) {
|
||||||
|
return {nullptr, 0.0, nullptr, 0.0};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expensive elimination of product factor.
|
// Expensive elimination of product factor.
|
||||||
auto result =
|
auto [conditional, factor] =
|
||||||
EliminatePreferCholesky(graph, keys); /// <<<<<< MOST COMPUTE IS HERE
|
EliminatePreferCholesky(graph, keys); /// <<<<<< MOST COMPUTE IS HERE
|
||||||
|
|
||||||
// Record whether there any continuous variables left
|
// Record whether there any continuous variables left
|
||||||
someContinuousLeft |= !result.second->empty();
|
someContinuousLeft |= !factor->empty();
|
||||||
|
|
||||||
// We pass on the scalar unmodified.
|
// We pass on the scalar unmodified.
|
||||||
return {result, scalar};
|
return {conditional, conditional->negLogConstant(), factor, scalar};
|
||||||
};
|
};
|
||||||
|
|
||||||
// Perform elimination!
|
// Perform elimination!
|
||||||
ResultTree eliminationResults(prunedProductFactor, eliminate);
|
const ResultTree eliminationResults(productFactor, eliminate);
|
||||||
|
|
||||||
// If there are no more continuous parents we create a DiscreteFactor with the
|
// If there are no more continuous parents we create a DiscreteFactor with the
|
||||||
// error for each discrete choice. Otherwise, create a HybridGaussianFactor
|
// error for each discrete choice. Otherwise, create a HybridGaussianFactor
|
||||||
|
|
@ -400,12 +403,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
? createHybridGaussianFactor(eliminationResults, discreteSeparator)
|
? createHybridGaussianFactor(eliminationResults, discreteSeparator)
|
||||||
: createDiscreteFactor(eliminationResults, discreteSeparator);
|
: createDiscreteFactor(eliminationResults, discreteSeparator);
|
||||||
|
|
||||||
// Create the HybridGaussianConditional from the conditionals
|
// Create the HybridGaussianConditional without re-calculating constants:
|
||||||
HybridGaussianConditional::Conditionals conditionals(
|
HybridGaussianConditional::FactorValuePairs pairs(
|
||||||
eliminationResults,
|
eliminationResults, [](const Result &result) -> GaussianFactorValuePair {
|
||||||
[](const ResultValuePair &pair) { return pair.first.first; });
|
return {result.conditional, result.negLogK};
|
||||||
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
|
});
|
||||||
discreteSeparator, conditionals);
|
auto hybridGaussian =
|
||||||
|
std::make_shared<HybridGaussianConditional>(discreteSeparator, pairs);
|
||||||
|
|
||||||
return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
|
return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue