Merge pull request #1575 from borglab/hybrid-tablefactor-2
commit
ba7c077a25
|
@ -93,7 +93,8 @@ namespace gtsam {
|
||||||
/// print
|
/// print
|
||||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter) const override {
|
const ValueFormatter& valueFormatter) const override {
|
||||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
std::cout << s << " Leaf [" << nrAssignments() << "] "
|
||||||
|
<< valueFormatter(constant_) << std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Write graphviz format to stream `os`. */
|
/** Write graphviz format to stream `os`. */
|
||||||
|
@ -827,6 +828,16 @@ namespace gtsam {
|
||||||
return total;
|
return total;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
size_t DecisionTree<L, Y>::nrAssignments() const {
|
||||||
|
size_t n = 0;
|
||||||
|
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
|
||||||
|
n += leaf.nrAssignments();
|
||||||
|
});
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
// fold is just done with a visit
|
// fold is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
|
|
|
@ -320,6 +320,42 @@ namespace gtsam {
|
||||||
/// Return the number of leaves in the tree.
|
/// Return the number of leaves in the tree.
|
||||||
size_t nrLeaves() const;
|
size_t nrLeaves() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief This is a convenience function which returns the total number of
|
||||||
|
* leaf assignments in the decision tree.
|
||||||
|
* This function is not used for anymajor operations within the discrete
|
||||||
|
* factor graph framework.
|
||||||
|
*
|
||||||
|
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
|
||||||
|
* binary tree each leaf has 2 assignments. This includes counts removed
|
||||||
|
* from implicit pruning hence, it will always be >= nrLeaves().
|
||||||
|
*
|
||||||
|
* E.g. we have a decision tree as below, where each node has 2 branches:
|
||||||
|
*
|
||||||
|
* Choice(m1)
|
||||||
|
* 0 Choice(m0)
|
||||||
|
* 0 0 Leaf 0.0
|
||||||
|
* 0 1 Leaf 0.0
|
||||||
|
* 1 Choice(m0)
|
||||||
|
* 1 0 Leaf 1.0
|
||||||
|
* 1 1 Leaf 2.0
|
||||||
|
*
|
||||||
|
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
|
||||||
|
* and 4 leaves.
|
||||||
|
*
|
||||||
|
* In the pruned form, the number of assignments is still 4 but the number
|
||||||
|
* of leaves is now 3, as below:
|
||||||
|
*
|
||||||
|
* Choice(m1)
|
||||||
|
* 0 Leaf 0.0
|
||||||
|
* 1 Choice(m0)
|
||||||
|
* 1 0 Leaf 1.0
|
||||||
|
* 1 1 Leaf 2.0
|
||||||
|
*
|
||||||
|
* @return size_t
|
||||||
|
*/
|
||||||
|
size_t nrAssignments() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Fold a binary function over the tree, returning accumulator.
|
* @brief Fold a binary function over the tree, returning accumulator.
|
||||||
*
|
*
|
||||||
|
|
|
@ -101,6 +101,14 @@ namespace gtsam {
|
||||||
return DecisionTreeFactor(keys, result);
|
return DecisionTreeFactor(keys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
|
||||||
|
// apply operand
|
||||||
|
ADT result = ADT::apply(op);
|
||||||
|
// Make a new factor
|
||||||
|
return DecisionTreeFactor(discreteKeys(), result);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||||
size_t nrFrontals, ADT::Binary op) const {
|
size_t nrFrontals, ADT::Binary op) const {
|
||||||
|
|
|
@ -182,6 +182,12 @@ namespace gtsam {
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply unary operator (*this) "op" f
|
||||||
|
* @param op a unary operator that operates on AlgebraicDecisionTree
|
||||||
|
*/
|
||||||
|
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Apply binary operator (*this) "op" f
|
* Apply binary operator (*this) "op" f
|
||||||
* @param f the second argument for op
|
* @param f the second argument for op
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include <gtsam/base/serializationTestHelpers.h>
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
|
||||||
|
|
|
@ -37,24 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||||
return Base::equals(bn, tol);
|
return Base::equals(bn, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
|
||||||
AlgebraicDecisionTree<Key> discreteProbs;
|
|
||||||
|
|
||||||
// The canonical decision tree factor which will get
|
|
||||||
// the discrete conditionals added to it.
|
|
||||||
DecisionTreeFactor discreteProbsFactor;
|
|
||||||
|
|
||||||
for (auto &&conditional : *this) {
|
|
||||||
if (conditional->isDiscrete()) {
|
|
||||||
// Convert to a DecisionTreeFactor and add it to the main factor.
|
|
||||||
DecisionTreeFactor f(*conditional->asDiscrete());
|
|
||||||
discreteProbsFactor = discreteProbsFactor * f;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functional.
|
* @brief Helper function to get the pruner functional.
|
||||||
|
@ -144,53 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridBayesNet::updateDiscreteConditionals(
|
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
||||||
const DecisionTreeFactor &prunedDiscreteProbs) {
|
size_t maxNrLeaves) {
|
||||||
KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
|
// Get the joint distribution of only the discrete keys
|
||||||
|
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
||||||
|
// The joint discrete probability.
|
||||||
|
DiscreteConditional discreteProbs;
|
||||||
|
|
||||||
|
std::vector<size_t> discrete_factor_idxs;
|
||||||
|
// Record frontal keys so we can maintain ordering
|
||||||
|
Ordering discrete_frontals;
|
||||||
|
|
||||||
// Loop with index since we need it later.
|
|
||||||
for (size_t i = 0; i < this->size(); i++) {
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
HybridConditional::shared_ptr conditional = this->at(i);
|
auto conditional = this->at(i);
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
auto discrete = conditional->asDiscrete();
|
discreteProbs = discreteProbs * (*conditional->asDiscrete());
|
||||||
|
|
||||||
// Convert pointer from conditional to factor
|
Ordering conditional_keys(conditional->frontals());
|
||||||
auto discreteTree =
|
discrete_frontals += conditional_keys;
|
||||||
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
discrete_factor_idxs.push_back(i);
|
||||||
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
|
||||||
DecisionTreeFactor::ADT prunedDiscreteTree =
|
|
||||||
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
|
|
||||||
|
|
||||||
gttic_(HybridBayesNet_MakeConditional);
|
|
||||||
// Create the new (hybrid) conditional
|
|
||||||
KeyVector frontals(discrete->frontals().begin(),
|
|
||||||
discrete->frontals().end());
|
|
||||||
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
|
|
||||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
|
||||||
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
|
|
||||||
gttoc_(HybridBayesNet_MakeConditional);
|
|
||||||
|
|
||||||
// Add it back to the BayesNet
|
|
||||||
this->at(i) = conditional;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const DecisionTreeFactor prunedDiscreteProbs =
|
||||||
|
discreteProbs.prune(maxNrLeaves);
|
||||||
|
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
||||||
|
|
||||||
|
// Eliminate joint probability back into conditionals
|
||||||
|
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||||
|
DiscreteFactorGraph dfg{prunedDiscreteProbs};
|
||||||
|
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
|
||||||
|
|
||||||
|
// Assign pruned discrete conditionals back at the correct indices.
|
||||||
|
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
|
||||||
|
size_t idx = discrete_factor_idxs.at(i);
|
||||||
|
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
|
||||||
|
}
|
||||||
|
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||||
|
|
||||||
|
return prunedDiscreteProbs;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
// Get the decision tree of only the discrete keys
|
DecisionTreeFactor prunedDiscreteProbs =
|
||||||
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
this->pruneDiscreteConditionals(maxNrLeaves);
|
||||||
DecisionTreeFactor::shared_ptr discreteConditionals =
|
|
||||||
this->discreteConditionals();
|
|
||||||
const DecisionTreeFactor prunedDiscreteProbs =
|
|
||||||
discreteConditionals->prune(maxNrLeaves);
|
|
||||||
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
|
||||||
|
|
||||||
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
/* To prune, we visitWith every leaf in the GaussianMixture.
|
||||||
this->updateDiscreteConditionals(prunedDiscreteProbs);
|
|
||||||
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
|
||||||
|
|
||||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
|
||||||
* For each leaf, using the assignment we can check the discrete decision tree
|
* For each leaf, using the assignment we can check the discrete decision tree
|
||||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||||
*
|
*
|
||||||
|
|
|
@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
*/
|
*/
|
||||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Get all the discrete conditionals as a decision tree factor.
|
|
||||||
*
|
|
||||||
* @return DecisionTreeFactor::shared_ptr
|
|
||||||
*/
|
|
||||||
DecisionTreeFactor::shared_ptr discreteConditionals() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Sample from an incomplete BayesNet, given missing variables.
|
* @brief Sample from an incomplete BayesNet, given missing variables.
|
||||||
*
|
*
|
||||||
|
@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* @brief Update the discrete conditionals with the pruned versions.
|
* @brief Prune all the discrete conditionals.
|
||||||
*
|
*
|
||||||
* @param prunedDiscreteProbs
|
* @param maxNrLeaves
|
||||||
*/
|
*/
|
||||||
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
|
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
|
||||||
|
|
||||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/inference/Factor.h>
|
#include <gtsam/inference/Factor.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
#include <gtsam/nonlinear/Values.h>
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
* @date January, 2023
|
* @date January, 2023
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
@ -26,7 +25,7 @@ namespace gtsam {
|
||||||
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
|
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
|
||||||
std::set<DiscreteKey> keys;
|
std::set<DiscreteKey> keys;
|
||||||
for (auto& factor : factors_) {
|
for (auto& factor : factors_) {
|
||||||
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
for (const DiscreteKey& key : p->discreteKeys()) {
|
for (const DiscreteKey& key : p->discreteKeys()) {
|
||||||
keys.insert(key);
|
keys.insert(key);
|
||||||
}
|
}
|
||||||
|
@ -67,6 +66,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
|
||||||
for (const Key& key : p->continuousKeys()) {
|
for (const Key& key : p->continuousKeys()) {
|
||||||
keys.insert(key);
|
keys.insert(key);
|
||||||
}
|
}
|
||||||
|
} else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
|
keys.insert(p->keys().begin(), p->keys().end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return keys;
|
return keys;
|
||||||
|
|
|
@ -48,8 +48,6 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// #define HYBRID_TIMING
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||||
|
@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
// TODO(dellaert): in C++20, we can use std::visit.
|
// TODO(dellaert): in C++20, we can use std::visit.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
// Don't do anything for discrete-only factors
|
// Don't do anything for discrete-only factors
|
||||||
// since we want to eliminate continuous values only.
|
// since we want to eliminate continuous values only.
|
||||||
continue;
|
continue;
|
||||||
|
@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg;
|
||||||
|
|
||||||
for (auto &f : factors) {
|
for (auto &f : factors) {
|
||||||
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
dfg.push_back(dtf);
|
dfg.push_back(df);
|
||||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||||
// Ignore orphaned clique.
|
// Ignore orphaned clique.
|
||||||
// TODO(dellaert): is this correct? If so explain here.
|
// TODO(dellaert): is this correct? If so explain here.
|
||||||
|
@ -262,6 +260,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
};
|
};
|
||||||
|
|
||||||
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
std::make_shared<HybridConditional>(gaussianMixture),
|
std::make_shared<HybridConditional>(gaussianMixture),
|
||||||
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
|
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
|
||||||
|
@ -348,64 +347,68 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
// When the number of assignments is large we may encounter stack overflows.
|
// When the number of assignments is large we may encounter stack overflows.
|
||||||
// However this is also the case with iSAM2, so no pressure :)
|
// However this is also the case with iSAM2, so no pressure :)
|
||||||
|
|
||||||
// PREPROCESS: Identify the nature of the current elimination
|
// Check the factors:
|
||||||
|
|
||||||
// TODO(dellaert): just check the factors:
|
|
||||||
// 1. if all factors are discrete, then we can do discrete elimination:
|
// 1. if all factors are discrete, then we can do discrete elimination:
|
||||||
// 2. if all factors are continuous, then we can do continuous elimination:
|
// 2. if all factors are continuous, then we can do continuous elimination:
|
||||||
// 3. if not, we do hybrid elimination:
|
// 3. if not, we do hybrid elimination:
|
||||||
|
|
||||||
// First, identify the separator keys, i.e. all keys that are not frontal.
|
bool only_discrete = true, only_continuous = true;
|
||||||
KeySet separatorKeys;
|
|
||||||
for (auto &&factor : factors) {
|
for (auto &&factor : factors) {
|
||||||
separatorKeys.insert(factor->begin(), factor->end());
|
if (auto hybrid_factor = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||||
|
if (hybrid_factor->isDiscrete()) {
|
||||||
|
only_continuous = false;
|
||||||
|
} else if (hybrid_factor->isContinuous()) {
|
||||||
|
only_discrete = false;
|
||||||
|
} else if (hybrid_factor->isHybrid()) {
|
||||||
|
only_continuous = false;
|
||||||
|
only_discrete = false;
|
||||||
}
|
}
|
||||||
// remove frontals from separator
|
} else if (auto cont_factor =
|
||||||
for (auto &k : frontalKeys) {
|
std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
separatorKeys.erase(k);
|
only_discrete = false;
|
||||||
}
|
} else if (auto discrete_factor =
|
||||||
|
std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
// Build a map from keys to DiscreteKeys
|
only_continuous = false;
|
||||||
auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
|
|
||||||
|
|
||||||
// Fill in discrete frontals and continuous frontals.
|
|
||||||
std::set<DiscreteKey> discreteFrontals;
|
|
||||||
KeySet continuousFrontals;
|
|
||||||
for (auto &k : frontalKeys) {
|
|
||||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
|
||||||
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
|
|
||||||
} else {
|
|
||||||
continuousFrontals.insert(k);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill in discrete discrete separator keys and continuous separator keys.
|
|
||||||
std::set<DiscreteKey> discreteSeparatorSet;
|
|
||||||
KeyVector continuousSeparator;
|
|
||||||
for (auto &k : separatorKeys) {
|
|
||||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
|
||||||
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
|
|
||||||
} else {
|
|
||||||
continuousSeparator.push_back(k);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have any continuous keys:
|
|
||||||
const bool discrete_only =
|
|
||||||
continuousFrontals.empty() && continuousSeparator.empty();
|
|
||||||
|
|
||||||
// NOTE: We should really defer the product here because of pruning
|
// NOTE: We should really defer the product here because of pruning
|
||||||
|
|
||||||
if (discrete_only) {
|
if (only_discrete) {
|
||||||
// Case 1: we are only dealing with discrete
|
// Case 1: we are only dealing with discrete
|
||||||
return discreteElimination(factors, frontalKeys);
|
return discreteElimination(factors, frontalKeys);
|
||||||
} else if (mapFromKeyToDiscreteKey.empty()) {
|
} else if (only_continuous) {
|
||||||
// Case 2: we are only dealing with continuous
|
// Case 2: we are only dealing with continuous
|
||||||
return continuousElimination(factors, frontalKeys);
|
return continuousElimination(factors, frontalKeys);
|
||||||
} else {
|
} else {
|
||||||
// Case 3: We are now in the hybrid land!
|
// Case 3: We are now in the hybrid land!
|
||||||
|
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
|
||||||
|
|
||||||
|
// Find all the keys in the set of continuous keys
|
||||||
|
// which are not in the frontal keys. This is our continuous separator.
|
||||||
|
KeyVector continuousSeparator;
|
||||||
|
auto continuousKeySet = factors.continuousKeySet();
|
||||||
|
std::set_difference(
|
||||||
|
continuousKeySet.begin(), continuousKeySet.end(),
|
||||||
|
frontalKeysSet.begin(), frontalKeysSet.end(),
|
||||||
|
std::inserter(continuousSeparator, continuousSeparator.begin()));
|
||||||
|
|
||||||
|
// Similarly for the discrete separator.
|
||||||
|
KeySet discreteSeparatorSet;
|
||||||
|
std::set<DiscreteKey> discreteSeparator;
|
||||||
|
auto discreteKeySet = factors.discreteKeySet();
|
||||||
|
std::set_difference(
|
||||||
|
discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(),
|
||||||
|
frontalKeysSet.end(),
|
||||||
|
std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin()));
|
||||||
|
// Convert from set of keys to set of DiscreteKeys
|
||||||
|
auto discreteKeyMap = factors.discreteKeyMap();
|
||||||
|
for (auto key : discreteSeparatorSet) {
|
||||||
|
discreteSeparator.insert(discreteKeyMap.at(key));
|
||||||
|
}
|
||||||
|
|
||||||
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
||||||
discreteSeparatorSet);
|
discreteSeparator);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -429,7 +432,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
// Add the gaussian factor error to every leaf of the error tree.
|
// Add the gaussian factor error to every leaf of the error tree.
|
||||||
error_tree = error_tree.apply(
|
error_tree = error_tree.apply(
|
||||||
[error](double leaf_value) { return leaf_value + error; });
|
[error](double leaf_value) { return leaf_value + error; });
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
// If factor at `idx` is discrete-only, we skip.
|
// If factor at `idx` is discrete-only, we skip.
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -40,6 +40,7 @@ class HybridEliminationTree;
|
||||||
class HybridBayesTree;
|
class HybridBayesTree;
|
||||||
class HybridJunctionTree;
|
class HybridJunctionTree;
|
||||||
class DecisionTreeFactor;
|
class DecisionTreeFactor;
|
||||||
|
class TableFactor;
|
||||||
class JacobianFactor;
|
class JacobianFactor;
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
|
||||||
for (auto& k : hf->discreteKeys()) {
|
for (auto& k : hf->discreteKeys()) {
|
||||||
data.discreteKeys.insert(k.first);
|
data.discreteKeys.insert(k.first);
|
||||||
}
|
}
|
||||||
} else if (auto hf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
for (auto& k : hf->discreteKeys()) {
|
for (auto& k : hf->discreteKeys()) {
|
||||||
data.discreteKeys.insert(k.first);
|
data.discreteKeys.insert(k.first);
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/hybrid/GaussianMixture.h>
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||||
|
@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
||||||
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
|
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
|
||||||
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
||||||
linearFG->push_back(gf);
|
linearFG->push_back(gf);
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
// If discrete-only: doesn't need linearization.
|
// If discrete-only: doesn't need linearization.
|
||||||
linearFG->push_back(f);
|
linearFG->push_back(f);
|
||||||
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
|
|
|
@ -72,7 +72,8 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
||||||
addConditionals(graph, hybridBayesNet_, ordering);
|
addConditionals(graph, hybridBayesNet_, ordering);
|
||||||
|
|
||||||
// Eliminate.
|
// Eliminate.
|
||||||
auto bayesNetFragment = graph.eliminateSequential(ordering);
|
HybridBayesNet::shared_ptr bayesNetFragment =
|
||||||
|
graph.eliminateSequential(ordering);
|
||||||
|
|
||||||
/// Prune
|
/// Prune
|
||||||
if (maxNrLeaves) {
|
if (maxNrLeaves) {
|
||||||
|
@ -96,7 +97,8 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
HybridGaussianFactorGraph graph(originalGraph);
|
HybridGaussianFactorGraph graph(originalGraph);
|
||||||
HybridBayesNet hybridBayesNet(originalHybridBayesNet);
|
HybridBayesNet hybridBayesNet(originalHybridBayesNet);
|
||||||
|
|
||||||
// If we are not at the first iteration, means we have conditionals to add.
|
// If hybridBayesNet is not empty,
|
||||||
|
// it means we have conditionals to add to the factor graph.
|
||||||
if (!hybridBayesNet.empty()) {
|
if (!hybridBayesNet.empty()) {
|
||||||
// We add all relevant conditional mixtures on the last continuous variable
|
// We add all relevant conditional mixtures on the last continuous variable
|
||||||
// in the previous `hybridBayesNet` to the graph
|
// in the previous `hybridBayesNet` to the graph
|
||||||
|
|
|
@ -202,31 +202,16 @@ struct Switching {
|
||||||
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
|
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
|
||||||
* E.g. if K=4, we want M0, M1 and M2.
|
* E.g. if K=4, we want M0, M1 and M2.
|
||||||
*
|
*
|
||||||
* @param fg The nonlinear factor graph to which the mode chain is added.
|
* @param fg The factor graph to which the mode chain is added.
|
||||||
*/
|
*/
|
||||||
void addModeChain(HybridNonlinearFactorGraph *fg,
|
template <typename FACTORGRAPH>
|
||||||
|
void addModeChain(FACTORGRAPH *fg,
|
||||||
std::string discrete_transition_prob = "1/2 3/2") {
|
std::string discrete_transition_prob = "1/2 3/2") {
|
||||||
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
fg->template emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||||
for (size_t k = 0; k < K - 2; k++) {
|
for (size_t k = 0; k < K - 2; k++) {
|
||||||
auto parents = {modes[k]};
|
auto parents = {modes[k]};
|
||||||
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
fg->template emplace_shared<DiscreteConditional>(
|
||||||
discrete_transition_prob);
|
modes[k + 1], parents, discrete_transition_prob);
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Add "mode chain" to HybridGaussianFactorGraph from M(0) to M(K-2).
|
|
||||||
* E.g. if K=4, we want M0, M1 and M2.
|
|
||||||
*
|
|
||||||
* @param fg The gaussian factor graph to which the mode chain is added.
|
|
||||||
*/
|
|
||||||
void addModeChain(HybridGaussianFactorGraph *fg,
|
|
||||||
std::string discrete_transition_prob = "1/2 3/2") {
|
|
||||||
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
|
||||||
for (size_t k = 0; k < K - 2; k++) {
|
|
||||||
auto parents = {modes[k]};
|
|
||||||
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
|
||||||
discrete_transition_prob);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
||||||
std::string expected =
|
std::string expected =
|
||||||
R"(Hybrid [x1 x2; 1]{
|
R"(Hybrid [x1 x2; 1]{
|
||||||
Choice(1)
|
Choice(1)
|
||||||
0 Leaf :
|
0 Leaf [1] :
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
0;
|
0;
|
||||||
0
|
0
|
||||||
|
@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
||||||
b = [ 0 0 ]
|
b = [ 0 0 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf :
|
1 Leaf [1] :
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
0;
|
0;
|
||||||
0
|
0
|
||||||
|
|
|
@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
||||||
|
|
||||||
// Regression test on pruned logProbability tree
|
// Regression test on pruned logProbability tree
|
||||||
std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
|
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
|
||||||
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
||||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||||
|
|
||||||
|
@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
logProbability +=
|
logProbability +=
|
||||||
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
||||||
|
|
||||||
|
// Regression
|
||||||
double density = exp(logProbability);
|
double density = exp(logProbability);
|
||||||
EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(density,
|
||||||
|
1.6078460548731697 * actualTree(discrete_values), 1e-6);
|
||||||
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
||||||
1e-9);
|
1e-9);
|
||||||
|
@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
EXPECT_LONGS_EQUAL(7, posterior->size());
|
EXPECT_LONGS_EQUAL(7, posterior->size());
|
||||||
|
|
||||||
size_t maxNrLeaves = 3;
|
size_t maxNrLeaves = 3;
|
||||||
auto discreteConditionals = posterior->discreteConditionals();
|
DiscreteConditional discreteConditionals;
|
||||||
|
for (auto&& conditional : *posterior) {
|
||||||
|
if (conditional->isDiscrete()) {
|
||||||
|
discreteConditionals =
|
||||||
|
discreteConditionals * (*conditional->asDiscrete());
|
||||||
|
}
|
||||||
|
}
|
||||||
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
||||||
std::make_shared<DecisionTreeFactor>(
|
std::make_shared<DecisionTreeFactor>(
|
||||||
discreteConditionals->prune(maxNrLeaves));
|
discreteConditionals.prune(maxNrLeaves));
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||||
prunedDecisionTree->nrLeaves());
|
prunedDecisionTree->nrLeaves());
|
||||||
|
|
||||||
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
|
// regression
|
||||||
|
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||||
|
DecisionTreeFactor::ADT potentials(
|
||||||
|
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||||
|
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
|
||||||
|
|
||||||
// Prune!
|
// Prune!
|
||||||
posterior->prune(maxNrLeaves);
|
posterior->prune(maxNrLeaves);
|
||||||
|
|
||||||
// Functor to verify values against the original_discrete_conditionals
|
// Functor to verify values against the expected_discrete_conditionals
|
||||||
auto checker = [&](const Assignment<Key>& assignment,
|
auto checker = [&](const Assignment<Key>& assignment,
|
||||||
double probability) -> double {
|
double probability) -> double {
|
||||||
// typecast so we can use this to get probability value
|
// typecast so we can use this to get probability value
|
||||||
|
@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
if (prunedDecisionTree->operator()(choices) == 0) {
|
if (prunedDecisionTree->operator()(choices) == 0) {
|
||||||
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
||||||
} else {
|
} else {
|
||||||
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
|
EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
|
||||||
1e-9);
|
1e-9);
|
||||||
}
|
}
|
||||||
return 0.0;
|
return 0.0;
|
||||||
|
|
|
@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
|
||||||
|
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg;
|
||||||
for (auto&& f : *remainingFactorGraph) {
|
for (auto&& f : *remainingFactorGraph) {
|
||||||
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
|
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
|
||||||
assert(discreteFactor);
|
assert(discreteFactor);
|
||||||
dfg.push_back(discreteFactor);
|
dfg.push_back(discreteFactor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -140,6 +140,61 @@ TEST(HybridEstimation, IncrementalSmoother) {
|
||||||
EXPECT(assert_equal(expected_continuous, result));
|
EXPECT(assert_equal(expected_continuous, result));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
// Test approximate inference with an additional pruning step.
|
||||||
|
TEST(HybridEstimation, ISAM) {
|
||||||
|
size_t K = 15;
|
||||||
|
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
|
||||||
|
7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
|
||||||
|
// Ground truth discrete seq
|
||||||
|
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
|
||||||
|
1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
|
||||||
|
// Switching example of robot moving in 1D
|
||||||
|
// with given measurements and equal mode priors.
|
||||||
|
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
||||||
|
HybridNonlinearISAM isam;
|
||||||
|
HybridNonlinearFactorGraph graph;
|
||||||
|
Values initial;
|
||||||
|
|
||||||
|
// gttic_(Estimation);
|
||||||
|
|
||||||
|
// Add the X(0) prior
|
||||||
|
graph.push_back(switching.nonlinearFactorGraph.at(0));
|
||||||
|
initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
|
||||||
|
|
||||||
|
HybridGaussianFactorGraph linearized;
|
||||||
|
|
||||||
|
for (size_t k = 1; k < K; k++) {
|
||||||
|
// Motion Model
|
||||||
|
graph.push_back(switching.nonlinearFactorGraph.at(k));
|
||||||
|
// Measurement
|
||||||
|
graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
|
||||||
|
|
||||||
|
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||||
|
|
||||||
|
isam.update(graph, initial, 3);
|
||||||
|
// isam.bayesTree().print("\n\n");
|
||||||
|
|
||||||
|
graph.resize(0);
|
||||||
|
initial.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
Values result = isam.estimate();
|
||||||
|
DiscreteValues assignment = isam.assignment();
|
||||||
|
|
||||||
|
DiscreteValues expected_discrete;
|
||||||
|
for (size_t k = 0; k < K - 1; k++) {
|
||||||
|
expected_discrete[M(k)] = discrete_seq[k];
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_discrete, assignment));
|
||||||
|
|
||||||
|
Values expected_continuous;
|
||||||
|
for (size_t k = 0; k < K; k++) {
|
||||||
|
expected_continuous.insert(X(k), measurements[k]);
|
||||||
|
}
|
||||||
|
EXPECT(assert_equal(expected_continuous, result));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A function to get a specific 1D robot motion problem as a linearized
|
* @brief A function to get a specific 1D robot motion problem as a linearized
|
||||||
* factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous
|
* factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous
|
||||||
|
|
|
@ -18,7 +18,9 @@
|
||||||
#include <gtsam/base/TestableAssertions.h>
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
#include <gtsam/base/utilities.h>
|
#include <gtsam/base/utilities.h>
|
||||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||||
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
#include <gtsam/nonlinear/PriorFactor.h>
|
#include <gtsam/nonlinear/PriorFactor.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) {
|
||||||
HybridFactorGraph fg;
|
HybridFactorGraph fg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test if methods to get keys work as expected.
|
||||||
|
TEST(HybridFactorGraph, Keys) {
|
||||||
|
HybridGaussianFactorGraph hfg;
|
||||||
|
|
||||||
|
// Add prior on x0
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
// Add factor between x0 and x1
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
// Add a gaussian mixture factor ϕ(x1, c1)
|
||||||
|
DiscreteKey m1(M(1), 2);
|
||||||
|
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
|
||||||
|
M(1), std::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
|
||||||
|
std::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
|
||||||
|
hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
|
||||||
|
|
||||||
|
KeySet expected_continuous{X(0), X(1)};
|
||||||
|
EXPECT(
|
||||||
|
assert_container_equality(expected_continuous, hfg.continuousKeySet()));
|
||||||
|
|
||||||
|
KeySet expected_discrete{M(1)};
|
||||||
|
EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet()));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -902,7 +902,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
|
||||||
// Test resulting posterior Bayes net has correct size:
|
// Test resulting posterior Bayes net has correct size:
|
||||||
EXPECT_LONGS_EQUAL(8, posterior->size());
|
EXPECT_LONGS_EQUAL(8, posterior->size());
|
||||||
|
|
||||||
// TODO(dellaert): this test fails - no idea why.
|
// Ratio test
|
||||||
EXPECT(ratioTest(bn, measurements, *posterior));
|
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -492,7 +492,7 @@ factor 0:
|
||||||
factor 1:
|
factor 1:
|
||||||
Hybrid [x0 x1; m0]{
|
Hybrid [x0 x1; m0]{
|
||||||
Choice(m0)
|
Choice(m0)
|
||||||
0 Leaf :
|
0 Leaf [1] :
|
||||||
A[x0] = [
|
A[x0] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{
|
||||||
b = [ -1 ]
|
b = [ -1 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf :
|
1 Leaf [1] :
|
||||||
A[x0] = [
|
A[x0] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{
|
||||||
factor 2:
|
factor 2:
|
||||||
Hybrid [x1 x2; m1]{
|
Hybrid [x1 x2; m1]{
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Leaf :
|
0 Leaf [1] :
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{
|
||||||
b = [ -1 ]
|
b = [ -1 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf :
|
1 Leaf [1] :
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
-1
|
-1
|
||||||
]
|
]
|
||||||
|
@ -550,16 +550,16 @@ factor 4:
|
||||||
b = [ -10 ]
|
b = [ -10 ]
|
||||||
No noise model
|
No noise model
|
||||||
factor 5: P( m0 ):
|
factor 5: P( m0 ):
|
||||||
Leaf 0.5
|
Leaf [2] 0.5
|
||||||
|
|
||||||
factor 6: P( m1 | m0 ):
|
factor 6: P( m1 | m0 ):
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
0 0 Leaf 0.33333333
|
0 0 Leaf [1] 0.33333333
|
||||||
0 1 Leaf 0.6
|
0 1 Leaf [1] 0.6
|
||||||
1 Choice(m0)
|
1 Choice(m0)
|
||||||
1 0 Leaf 0.66666667
|
1 0 Leaf [1] 0.66666667
|
||||||
1 1 Leaf 0.4
|
1 1 Leaf [1] 0.4
|
||||||
|
|
||||||
)";
|
)";
|
||||||
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
|
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
|
||||||
|
@ -570,13 +570,13 @@ size: 3
|
||||||
conditional 0: Hybrid P( x0 | x1 m0)
|
conditional 0: Hybrid P( x0 | x1 m0)
|
||||||
Discrete Keys = (m0, 2),
|
Discrete Keys = (m0, 2),
|
||||||
Choice(m0)
|
Choice(m0)
|
||||||
0 Leaf p(x0 | x1)
|
0 Leaf [1] p(x0 | x1)
|
||||||
R = [ 10.0499 ]
|
R = [ 10.0499 ]
|
||||||
S[x1] = [ -0.0995037 ]
|
S[x1] = [ -0.0995037 ]
|
||||||
d = [ -9.85087 ]
|
d = [ -9.85087 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Leaf p(x0 | x1)
|
1 Leaf [1] p(x0 | x1)
|
||||||
R = [ 10.0499 ]
|
R = [ 10.0499 ]
|
||||||
S[x1] = [ -0.0995037 ]
|
S[x1] = [ -0.0995037 ]
|
||||||
d = [ -9.95037 ]
|
d = [ -9.95037 ]
|
||||||
|
@ -586,26 +586,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1)
|
||||||
Discrete Keys = (m0, 2), (m1, 2),
|
Discrete Keys = (m0, 2), (m1, 2),
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
0 0 Leaf p(x1 | x2)
|
0 0 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -9.99901 ]
|
d = [ -9.99901 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
0 1 Leaf p(x1 | x2)
|
0 1 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -9.90098 ]
|
d = [ -9.90098 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Choice(m0)
|
1 Choice(m0)
|
||||||
1 0 Leaf p(x1 | x2)
|
1 0 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -10.098 ]
|
d = [ -10.098 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 1 Leaf p(x1 | x2)
|
1 1 Leaf [1] p(x1 | x2)
|
||||||
R = [ 10.099 ]
|
R = [ 10.099 ]
|
||||||
S[x2] = [ -0.0990196 ]
|
S[x2] = [ -0.0990196 ]
|
||||||
d = [ -10 ]
|
d = [ -10 ]
|
||||||
|
@ -615,14 +615,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
|
||||||
Discrete Keys = (m0, 2), (m1, 2),
|
Discrete Keys = (m0, 2), (m1, 2),
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
0 0 Leaf p(x2)
|
0 0 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.1489 ]
|
d = [ -10.1489 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
x2: -1.0099
|
x2: -1.0099
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
0 1 Leaf p(x2)
|
0 1 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.1479 ]
|
d = [ -10.1479 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
|
@ -630,14 +630,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 Choice(m0)
|
1 Choice(m0)
|
||||||
1 0 Leaf p(x2)
|
1 0 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.0504 ]
|
d = [ -10.0504 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
x2: -1.0001
|
x2: -1.0001
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
1 1 Leaf p(x2)
|
1 1 Leaf [1] p(x2)
|
||||||
R = [ 10.0494 ]
|
R = [ 10.0494 ]
|
||||||
d = [ -10.0494 ]
|
d = [ -10.0494 ]
|
||||||
mean: 1 elements
|
mean: 1 elements
|
||||||
|
|
|
@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) {
|
||||||
R"(Hybrid [x1 x2; 1]
|
R"(Hybrid [x1 x2; 1]
|
||||||
MixtureFactor
|
MixtureFactor
|
||||||
Choice(1)
|
Choice(1)
|
||||||
0 Leaf Nonlinear factor on 2 keys
|
0 Leaf [1] Nonlinear factor on 2 keys
|
||||||
1 Leaf Nonlinear factor on 2 keys
|
1 Leaf [1] Nonlinear factor on 2 keys
|
||||||
)";
|
)";
|
||||||
EXPECT(assert_print_equal(expected, mixtureFactor));
|
EXPECT(assert_print_equal(expected, mixtureFactor));
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,7 +99,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const {
|
void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const {
|
||||||
cout << s << " p(";
|
cout << (s.empty() ? "" : s + " ") << "p(";
|
||||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
cout << formatter(*it) << (nrFrontals() > 1 ? " " : "");
|
cout << formatter(*it) << (nrFrontals() > 1 ? " " : "");
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue