Merge pull request #1575 from borglab/hybrid-tablefactor-2
commit
ba7c077a25
|
@ -93,7 +93,8 @@ namespace gtsam {
|
|||
/// print
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const override {
|
||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||
std::cout << s << " Leaf [" << nrAssignments() << "] "
|
||||
<< valueFormatter(constant_) << std::endl;
|
||||
}
|
||||
|
||||
/** Write graphviz format to stream `os`. */
|
||||
|
@ -827,6 +828,16 @@ namespace gtsam {
|
|||
return total;
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
template <typename L, typename Y>
|
||||
size_t DecisionTree<L, Y>::nrAssignments() const {
|
||||
size_t n = 0;
|
||||
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
|
||||
n += leaf.nrAssignments();
|
||||
});
|
||||
return n;
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// fold is just done with a visit
|
||||
template <typename L, typename Y>
|
||||
|
|
|
@ -320,6 +320,42 @@ namespace gtsam {
|
|||
/// Return the number of leaves in the tree.
|
||||
size_t nrLeaves() const;
|
||||
|
||||
/**
|
||||
* @brief This is a convenience function which returns the total number of
|
||||
* leaf assignments in the decision tree.
|
||||
* This function is not used for anymajor operations within the discrete
|
||||
* factor graph framework.
|
||||
*
|
||||
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
|
||||
* binary tree each leaf has 2 assignments. This includes counts removed
|
||||
* from implicit pruning hence, it will always be >= nrLeaves().
|
||||
*
|
||||
* E.g. we have a decision tree as below, where each node has 2 branches:
|
||||
*
|
||||
* Choice(m1)
|
||||
* 0 Choice(m0)
|
||||
* 0 0 Leaf 0.0
|
||||
* 0 1 Leaf 0.0
|
||||
* 1 Choice(m0)
|
||||
* 1 0 Leaf 1.0
|
||||
* 1 1 Leaf 2.0
|
||||
*
|
||||
* In the unpruned form, the tree will have 4 assignments, 2 for each key,
|
||||
* and 4 leaves.
|
||||
*
|
||||
* In the pruned form, the number of assignments is still 4 but the number
|
||||
* of leaves is now 3, as below:
|
||||
*
|
||||
* Choice(m1)
|
||||
* 0 Leaf 0.0
|
||||
* 1 Choice(m0)
|
||||
* 1 0 Leaf 1.0
|
||||
* 1 1 Leaf 2.0
|
||||
*
|
||||
* @return size_t
|
||||
*/
|
||||
size_t nrAssignments() const;
|
||||
|
||||
/**
|
||||
* @brief Fold a binary function over the tree, returning accumulator.
|
||||
*
|
||||
|
|
|
@ -101,6 +101,14 @@ namespace gtsam {
|
|||
return DecisionTreeFactor(keys, result);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
|
||||
// apply operand
|
||||
ADT result = ADT::apply(op);
|
||||
// Make a new factor
|
||||
return DecisionTreeFactor(discreteKeys(), result);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||
size_t nrFrontals, ADT::Binary op) const {
|
||||
|
|
|
@ -182,6 +182,12 @@ namespace gtsam {
|
|||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* Apply unary operator (*this) "op" f
|
||||
* @param op a unary operator that operates on AlgebraicDecisionTree
|
||||
*/
|
||||
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
|
||||
|
||||
/**
|
||||
* Apply binary operator (*this) "op" f
|
||||
* @param f the second argument for op
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
#include <iomanip>
|
||||
|
||||
|
|
|
@ -37,24 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
|||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||
AlgebraicDecisionTree<Key> discreteProbs;
|
||||
|
||||
// The canonical decision tree factor which will get
|
||||
// the discrete conditionals added to it.
|
||||
DecisionTreeFactor discreteProbsFactor;
|
||||
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
// Convert to a DecisionTreeFactor and add it to the main factor.
|
||||
DecisionTreeFactor f(*conditional->asDiscrete());
|
||||
discreteProbsFactor = discreteProbsFactor * f;
|
||||
}
|
||||
}
|
||||
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
|
@ -144,53 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesNet::updateDiscreteConditionals(
|
||||
const DecisionTreeFactor &prunedDiscreteProbs) {
|
||||
KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
|
||||
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
||||
size_t maxNrLeaves) {
|
||||
// Get the joint distribution of only the discrete keys
|
||||
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
||||
// The joint discrete probability.
|
||||
DiscreteConditional discreteProbs;
|
||||
|
||||
std::vector<size_t> discrete_factor_idxs;
|
||||
// Record frontal keys so we can maintain ordering
|
||||
Ordering discrete_frontals;
|
||||
|
||||
// Loop with index since we need it later.
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
auto conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
auto discrete = conditional->asDiscrete();
|
||||
discreteProbs = discreteProbs * (*conditional->asDiscrete());
|
||||
|
||||
// Convert pointer from conditional to factor
|
||||
auto discreteTree =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
||||
// Apply prunerFunc to the underlying AlgebraicDecisionTree
|
||||
DecisionTreeFactor::ADT prunedDiscreteTree =
|
||||
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
|
||||
|
||||
gttic_(HybridBayesNet_MakeConditional);
|
||||
// Create the new (hybrid) conditional
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
discrete->frontals().end());
|
||||
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
|
||||
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
|
||||
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
|
||||
gttoc_(HybridBayesNet_MakeConditional);
|
||||
|
||||
// Add it back to the BayesNet
|
||||
this->at(i) = conditional;
|
||||
Ordering conditional_keys(conditional->frontals());
|
||||
discrete_frontals += conditional_keys;
|
||||
discrete_factor_idxs.push_back(i);
|
||||
}
|
||||
}
|
||||
const DecisionTreeFactor prunedDiscreteProbs =
|
||||
discreteProbs.prune(maxNrLeaves);
|
||||
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
||||
|
||||
// Eliminate joint probability back into conditionals
|
||||
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||
DiscreteFactorGraph dfg{prunedDiscreteProbs};
|
||||
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
|
||||
|
||||
// Assign pruned discrete conditionals back at the correct indices.
|
||||
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
|
||||
size_t idx = discrete_factor_idxs.at(i);
|
||||
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
|
||||
}
|
||||
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||
|
||||
return prunedDiscreteProbs;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||
// Get the decision tree of only the discrete keys
|
||||
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
||||
DecisionTreeFactor::shared_ptr discreteConditionals =
|
||||
this->discreteConditionals();
|
||||
const DecisionTreeFactor prunedDiscreteProbs =
|
||||
discreteConditionals->prune(maxNrLeaves);
|
||||
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
||||
DecisionTreeFactor prunedDiscreteProbs =
|
||||
this->pruneDiscreteConditionals(maxNrLeaves);
|
||||
|
||||
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||
this->updateDiscreteConditionals(prunedDiscreteProbs);
|
||||
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||
|
||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||
/* To prune, we visitWith every leaf in the GaussianMixture.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
*
|
||||
|
|
|
@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||
|
||||
/**
|
||||
* @brief Get all the discrete conditionals as a decision tree factor.
|
||||
*
|
||||
* @return DecisionTreeFactor::shared_ptr
|
||||
*/
|
||||
DecisionTreeFactor::shared_ptr discreteConditionals() const;
|
||||
|
||||
/**
|
||||
* @brief Sample from an incomplete BayesNet, given missing variables.
|
||||
*
|
||||
|
@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
|
||||
private:
|
||||
/**
|
||||
* @brief Update the discrete conditionals with the pruned versions.
|
||||
* @brief Prune all the discrete conditionals.
|
||||
*
|
||||
* @param prunedDiscreteProbs
|
||||
* @param maxNrLeaves
|
||||
*/
|
||||
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
|
||||
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
|
||||
|
||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||
/** Serialization function */
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/inference/Factor.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
#include <gtsam/nonlinear/Values.h>
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
* @date January, 2023
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -26,7 +25,7 @@ namespace gtsam {
|
|||
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
|
||||
std::set<DiscreteKey> keys;
|
||||
for (auto& factor : factors_) {
|
||||
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
for (const DiscreteKey& key : p->discreteKeys()) {
|
||||
keys.insert(key);
|
||||
}
|
||||
|
@ -67,6 +66,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
|
|||
for (const Key& key : p->continuousKeys()) {
|
||||
keys.insert(key);
|
||||
}
|
||||
} else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||
keys.insert(p->keys().begin(), p->keys().end());
|
||||
}
|
||||
}
|
||||
return keys;
|
||||
|
|
|
@ -48,8 +48,6 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// #define HYBRID_TIMING
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||
|
@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
// TODO(dellaert): in C++20, we can use std::visit.
|
||||
continue;
|
||||
}
|
||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
// Don't do anything for discrete-only factors
|
||||
// since we want to eliminate continuous values only.
|
||||
continue;
|
||||
|
@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
DiscreteFactorGraph dfg;
|
||||
|
||||
for (auto &f : factors) {
|
||||
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
dfg.push_back(dtf);
|
||||
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
dfg.push_back(df);
|
||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||
// Ignore orphaned clique.
|
||||
// TODO(dellaert): is this correct? If so explain here.
|
||||
|
@ -262,6 +260,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
};
|
||||
|
||||
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
||||
|
||||
return {
|
||||
std::make_shared<HybridConditional>(gaussianMixture),
|
||||
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
|
||||
|
@ -348,64 +347,68 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
// When the number of assignments is large we may encounter stack overflows.
|
||||
// However this is also the case with iSAM2, so no pressure :)
|
||||
|
||||
// PREPROCESS: Identify the nature of the current elimination
|
||||
|
||||
// TODO(dellaert): just check the factors:
|
||||
// Check the factors:
|
||||
// 1. if all factors are discrete, then we can do discrete elimination:
|
||||
// 2. if all factors are continuous, then we can do continuous elimination:
|
||||
// 3. if not, we do hybrid elimination:
|
||||
|
||||
// First, identify the separator keys, i.e. all keys that are not frontal.
|
||||
KeySet separatorKeys;
|
||||
bool only_discrete = true, only_continuous = true;
|
||||
for (auto &&factor : factors) {
|
||||
separatorKeys.insert(factor->begin(), factor->end());
|
||||
}
|
||||
// remove frontals from separator
|
||||
for (auto &k : frontalKeys) {
|
||||
separatorKeys.erase(k);
|
||||
}
|
||||
|
||||
// Build a map from keys to DiscreteKeys
|
||||
auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
|
||||
|
||||
// Fill in discrete frontals and continuous frontals.
|
||||
std::set<DiscreteKey> discreteFrontals;
|
||||
KeySet continuousFrontals;
|
||||
for (auto &k : frontalKeys) {
|
||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
||||
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
|
||||
} else {
|
||||
continuousFrontals.insert(k);
|
||||
if (auto hybrid_factor = std::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
if (hybrid_factor->isDiscrete()) {
|
||||
only_continuous = false;
|
||||
} else if (hybrid_factor->isContinuous()) {
|
||||
only_discrete = false;
|
||||
} else if (hybrid_factor->isHybrid()) {
|
||||
only_continuous = false;
|
||||
only_discrete = false;
|
||||
}
|
||||
} else if (auto cont_factor =
|
||||
std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||
only_discrete = false;
|
||||
} else if (auto discrete_factor =
|
||||
std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
only_continuous = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Fill in discrete discrete separator keys and continuous separator keys.
|
||||
std::set<DiscreteKey> discreteSeparatorSet;
|
||||
KeyVector continuousSeparator;
|
||||
for (auto &k : separatorKeys) {
|
||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
||||
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
|
||||
} else {
|
||||
continuousSeparator.push_back(k);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we have any continuous keys:
|
||||
const bool discrete_only =
|
||||
continuousFrontals.empty() && continuousSeparator.empty();
|
||||
|
||||
// NOTE: We should really defer the product here because of pruning
|
||||
|
||||
if (discrete_only) {
|
||||
if (only_discrete) {
|
||||
// Case 1: we are only dealing with discrete
|
||||
return discreteElimination(factors, frontalKeys);
|
||||
} else if (mapFromKeyToDiscreteKey.empty()) {
|
||||
} else if (only_continuous) {
|
||||
// Case 2: we are only dealing with continuous
|
||||
return continuousElimination(factors, frontalKeys);
|
||||
} else {
|
||||
// Case 3: We are now in the hybrid land!
|
||||
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
|
||||
|
||||
// Find all the keys in the set of continuous keys
|
||||
// which are not in the frontal keys. This is our continuous separator.
|
||||
KeyVector continuousSeparator;
|
||||
auto continuousKeySet = factors.continuousKeySet();
|
||||
std::set_difference(
|
||||
continuousKeySet.begin(), continuousKeySet.end(),
|
||||
frontalKeysSet.begin(), frontalKeysSet.end(),
|
||||
std::inserter(continuousSeparator, continuousSeparator.begin()));
|
||||
|
||||
// Similarly for the discrete separator.
|
||||
KeySet discreteSeparatorSet;
|
||||
std::set<DiscreteKey> discreteSeparator;
|
||||
auto discreteKeySet = factors.discreteKeySet();
|
||||
std::set_difference(
|
||||
discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(),
|
||||
frontalKeysSet.end(),
|
||||
std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin()));
|
||||
// Convert from set of keys to set of DiscreteKeys
|
||||
auto discreteKeyMap = factors.discreteKeyMap();
|
||||
for (auto key : discreteSeparatorSet) {
|
||||
discreteSeparator.insert(discreteKeyMap.at(key));
|
||||
}
|
||||
|
||||
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
||||
discreteSeparatorSet);
|
||||
discreteSeparator);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -429,7 +432,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
// Add the gaussian factor error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
} else {
|
||||
|
|
|
@ -40,6 +40,7 @@ class HybridEliminationTree;
|
|||
class HybridBayesTree;
|
||||
class HybridJunctionTree;
|
||||
class DecisionTreeFactor;
|
||||
class TableFactor;
|
||||
class JacobianFactor;
|
||||
class HybridValues;
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
|
|||
for (auto& k : hf->discreteKeys()) {
|
||||
data.discreteKeys.insert(k.first);
|
||||
}
|
||||
} else if (auto hf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
} else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
for (auto& k : hf->discreteKeys()) {
|
||||
data.discreteKeys.insert(k.first);
|
||||
}
|
||||
|
@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree(
|
|||
Data rootData(0);
|
||||
rootData.junctionTreeNode =
|
||||
std::make_shared<typename Base::Node>(); // Make a dummy node to gather
|
||||
// the junction tree roots
|
||||
// the junction tree roots
|
||||
treeTraversal::DepthFirstForest(eliminationTree, rootData,
|
||||
Data::ConstructorTraversalVisitorPre,
|
||||
Data::ConstructorTraversalVisitorPost);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||
|
@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
|||
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
|
||||
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
||||
linearFG->push_back(gf);
|
||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
// If discrete-only: doesn't need linearization.
|
||||
linearFG->push_back(f);
|
||||
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
|
|
|
@ -72,7 +72,8 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
|
|||
addConditionals(graph, hybridBayesNet_, ordering);
|
||||
|
||||
// Eliminate.
|
||||
auto bayesNetFragment = graph.eliminateSequential(ordering);
|
||||
HybridBayesNet::shared_ptr bayesNetFragment =
|
||||
graph.eliminateSequential(ordering);
|
||||
|
||||
/// Prune
|
||||
if (maxNrLeaves) {
|
||||
|
@ -96,7 +97,8 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
|||
HybridGaussianFactorGraph graph(originalGraph);
|
||||
HybridBayesNet hybridBayesNet(originalHybridBayesNet);
|
||||
|
||||
// If we are not at the first iteration, means we have conditionals to add.
|
||||
// If hybridBayesNet is not empty,
|
||||
// it means we have conditionals to add to the factor graph.
|
||||
if (!hybridBayesNet.empty()) {
|
||||
// We add all relevant conditional mixtures on the last continuous variable
|
||||
// in the previous `hybridBayesNet` to the graph
|
||||
|
|
|
@ -202,31 +202,16 @@ struct Switching {
|
|||
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
|
||||
* E.g. if K=4, we want M0, M1 and M2.
|
||||
*
|
||||
* @param fg The nonlinear factor graph to which the mode chain is added.
|
||||
* @param fg The factor graph to which the mode chain is added.
|
||||
*/
|
||||
void addModeChain(HybridNonlinearFactorGraph *fg,
|
||||
template <typename FACTORGRAPH>
|
||||
void addModeChain(FACTORGRAPH *fg,
|
||||
std::string discrete_transition_prob = "1/2 3/2") {
|
||||
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||
fg->template emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||
for (size_t k = 0; k < K - 2; k++) {
|
||||
auto parents = {modes[k]};
|
||||
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
||||
discrete_transition_prob);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Add "mode chain" to HybridGaussianFactorGraph from M(0) to M(K-2).
|
||||
* E.g. if K=4, we want M0, M1 and M2.
|
||||
*
|
||||
* @param fg The gaussian factor graph to which the mode chain is added.
|
||||
*/
|
||||
void addModeChain(HybridGaussianFactorGraph *fg,
|
||||
std::string discrete_transition_prob = "1/2 3/2") {
|
||||
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||
for (size_t k = 0; k < K - 2; k++) {
|
||||
auto parents = {modes[k]};
|
||||
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
||||
discrete_transition_prob);
|
||||
fg->template emplace_shared<DiscreteConditional>(
|
||||
modes[k + 1], parents, discrete_transition_prob);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
|||
std::string expected =
|
||||
R"(Hybrid [x1 x2; 1]{
|
||||
Choice(1)
|
||||
0 Leaf :
|
||||
0 Leaf [1] :
|
||||
A[x1] = [
|
||||
0;
|
||||
0
|
||||
|
@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) {
|
|||
b = [ 0 0 ]
|
||||
No noise model
|
||||
|
||||
1 Leaf :
|
||||
1 Leaf [1] :
|
||||
A[x1] = [
|
||||
0;
|
||||
0
|
||||
|
|
|
@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
|
|||
auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
|
||||
|
||||
// Regression test on pruned logProbability tree
|
||||
std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
|
||||
std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
|
||||
AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
|
||||
EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
|
||||
|
||||
|
@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
|
|||
logProbability +=
|
||||
posterior->at(4)->asDiscrete()->logProbability(hybridValues);
|
||||
|
||||
// Regression
|
||||
double density = exp(logProbability);
|
||||
EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(density,
|
||||
1.6078460548731697 * actualTree(discrete_values), 1e-6);
|
||||
EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
|
||||
1e-9);
|
||||
|
@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
EXPECT_LONGS_EQUAL(7, posterior->size());
|
||||
|
||||
size_t maxNrLeaves = 3;
|
||||
auto discreteConditionals = posterior->discreteConditionals();
|
||||
DiscreteConditional discreteConditionals;
|
||||
for (auto&& conditional : *posterior) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discreteConditionals =
|
||||
discreteConditionals * (*conditional->asDiscrete());
|
||||
}
|
||||
}
|
||||
const DecisionTreeFactor::shared_ptr prunedDecisionTree =
|
||||
std::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
discreteConditionals.prune(maxNrLeaves));
|
||||
|
||||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||
prunedDecisionTree->nrLeaves());
|
||||
|
||||
auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
|
||||
// regression
|
||||
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||
DecisionTreeFactor::ADT potentials(
|
||||
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
|
||||
|
||||
// Prune!
|
||||
posterior->prune(maxNrLeaves);
|
||||
|
||||
// Functor to verify values against the original_discrete_conditionals
|
||||
// Functor to verify values against the expected_discrete_conditionals
|
||||
auto checker = [&](const Assignment<Key>& assignment,
|
||||
double probability) -> double {
|
||||
// typecast so we can use this to get probability value
|
||||
|
@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
if (prunedDecisionTree->operator()(choices) == 0) {
|
||||
EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
|
||||
} else {
|
||||
EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
|
||||
EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
|
||||
1e-9);
|
||||
}
|
||||
return 0.0;
|
||||
|
|
|
@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
|
|||
|
||||
DiscreteFactorGraph dfg;
|
||||
for (auto&& f : *remainingFactorGraph) {
|
||||
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
|
||||
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
|
||||
assert(discreteFactor);
|
||||
dfg.push_back(discreteFactor);
|
||||
}
|
||||
|
|
|
@ -140,6 +140,61 @@ TEST(HybridEstimation, IncrementalSmoother) {
|
|||
EXPECT(assert_equal(expected_continuous, result));
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// Test approximate inference with an additional pruning step.
|
||||
TEST(HybridEstimation, ISAM) {
|
||||
size_t K = 15;
|
||||
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
|
||||
7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
|
||||
// Ground truth discrete seq
|
||||
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
|
||||
1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
|
||||
// Switching example of robot moving in 1D
|
||||
// with given measurements and equal mode priors.
|
||||
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
||||
HybridNonlinearISAM isam;
|
||||
HybridNonlinearFactorGraph graph;
|
||||
Values initial;
|
||||
|
||||
// gttic_(Estimation);
|
||||
|
||||
// Add the X(0) prior
|
||||
graph.push_back(switching.nonlinearFactorGraph.at(0));
|
||||
initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
|
||||
|
||||
HybridGaussianFactorGraph linearized;
|
||||
|
||||
for (size_t k = 1; k < K; k++) {
|
||||
// Motion Model
|
||||
graph.push_back(switching.nonlinearFactorGraph.at(k));
|
||||
// Measurement
|
||||
graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
|
||||
|
||||
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||
|
||||
isam.update(graph, initial, 3);
|
||||
// isam.bayesTree().print("\n\n");
|
||||
|
||||
graph.resize(0);
|
||||
initial.clear();
|
||||
}
|
||||
|
||||
Values result = isam.estimate();
|
||||
DiscreteValues assignment = isam.assignment();
|
||||
|
||||
DiscreteValues expected_discrete;
|
||||
for (size_t k = 0; k < K - 1; k++) {
|
||||
expected_discrete[M(k)] = discrete_seq[k];
|
||||
}
|
||||
EXPECT(assert_equal(expected_discrete, assignment));
|
||||
|
||||
Values expected_continuous;
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
expected_continuous.insert(X(k), measurements[k]);
|
||||
}
|
||||
EXPECT(assert_equal(expected_continuous, result));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief A function to get a specific 1D robot motion problem as a linearized
|
||||
* factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous
|
||||
|
|
|
@ -18,7 +18,9 @@
|
|||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/base/utilities.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/JacobianFactor.h>
|
||||
#include <gtsam/nonlinear/PriorFactor.h>
|
||||
|
||||
using namespace std;
|
||||
|
@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) {
|
|||
HybridFactorGraph fg;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test if methods to get keys work as expected.
|
||||
TEST(HybridFactorGraph, Keys) {
|
||||
HybridGaussianFactorGraph hfg;
|
||||
|
||||
// Add prior on x0
|
||||
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||
|
||||
// Add factor between x0 and x1
|
||||
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||
|
||||
// Add a gaussian mixture factor ϕ(x1, c1)
|
||||
DiscreteKey m1(M(1), 2);
|
||||
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
|
||||
M(1), std::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
|
||||
std::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
|
||||
hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
|
||||
|
||||
KeySet expected_continuous{X(0), X(1)};
|
||||
EXPECT(
|
||||
assert_container_equality(expected_continuous, hfg.continuousKeySet()));
|
||||
|
||||
KeySet expected_discrete{M(1)};
|
||||
EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -902,7 +902,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
|
|||
// Test resulting posterior Bayes net has correct size:
|
||||
EXPECT_LONGS_EQUAL(8, posterior->size());
|
||||
|
||||
// TODO(dellaert): this test fails - no idea why.
|
||||
// Ratio test
|
||||
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||
}
|
||||
|
||||
|
|
|
@ -492,7 +492,7 @@ factor 0:
|
|||
factor 1:
|
||||
Hybrid [x0 x1; m0]{
|
||||
Choice(m0)
|
||||
0 Leaf :
|
||||
0 Leaf [1] :
|
||||
A[x0] = [
|
||||
-1
|
||||
]
|
||||
|
@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{
|
|||
b = [ -1 ]
|
||||
No noise model
|
||||
|
||||
1 Leaf :
|
||||
1 Leaf [1] :
|
||||
A[x0] = [
|
||||
-1
|
||||
]
|
||||
|
@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{
|
|||
factor 2:
|
||||
Hybrid [x1 x2; m1]{
|
||||
Choice(m1)
|
||||
0 Leaf :
|
||||
0 Leaf [1] :
|
||||
A[x1] = [
|
||||
-1
|
||||
]
|
||||
|
@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{
|
|||
b = [ -1 ]
|
||||
No noise model
|
||||
|
||||
1 Leaf :
|
||||
1 Leaf [1] :
|
||||
A[x1] = [
|
||||
-1
|
||||
]
|
||||
|
@ -550,16 +550,16 @@ factor 4:
|
|||
b = [ -10 ]
|
||||
No noise model
|
||||
factor 5: P( m0 ):
|
||||
Leaf 0.5
|
||||
Leaf [2] 0.5
|
||||
|
||||
factor 6: P( m1 | m0 ):
|
||||
Choice(m1)
|
||||
0 Choice(m0)
|
||||
0 0 Leaf 0.33333333
|
||||
0 1 Leaf 0.6
|
||||
0 0 Leaf [1] 0.33333333
|
||||
0 1 Leaf [1] 0.6
|
||||
1 Choice(m0)
|
||||
1 0 Leaf 0.66666667
|
||||
1 1 Leaf 0.4
|
||||
1 0 Leaf [1] 0.66666667
|
||||
1 1 Leaf [1] 0.4
|
||||
|
||||
)";
|
||||
EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
|
||||
|
@ -570,13 +570,13 @@ size: 3
|
|||
conditional 0: Hybrid P( x0 | x1 m0)
|
||||
Discrete Keys = (m0, 2),
|
||||
Choice(m0)
|
||||
0 Leaf p(x0 | x1)
|
||||
0 Leaf [1] p(x0 | x1)
|
||||
R = [ 10.0499 ]
|
||||
S[x1] = [ -0.0995037 ]
|
||||
d = [ -9.85087 ]
|
||||
No noise model
|
||||
|
||||
1 Leaf p(x0 | x1)
|
||||
1 Leaf [1] p(x0 | x1)
|
||||
R = [ 10.0499 ]
|
||||
S[x1] = [ -0.0995037 ]
|
||||
d = [ -9.95037 ]
|
||||
|
@ -586,26 +586,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1)
|
|||
Discrete Keys = (m0, 2), (m1, 2),
|
||||
Choice(m1)
|
||||
0 Choice(m0)
|
||||
0 0 Leaf p(x1 | x2)
|
||||
0 0 Leaf [1] p(x1 | x2)
|
||||
R = [ 10.099 ]
|
||||
S[x2] = [ -0.0990196 ]
|
||||
d = [ -9.99901 ]
|
||||
No noise model
|
||||
|
||||
0 1 Leaf p(x1 | x2)
|
||||
0 1 Leaf [1] p(x1 | x2)
|
||||
R = [ 10.099 ]
|
||||
S[x2] = [ -0.0990196 ]
|
||||
d = [ -9.90098 ]
|
||||
No noise model
|
||||
|
||||
1 Choice(m0)
|
||||
1 0 Leaf p(x1 | x2)
|
||||
1 0 Leaf [1] p(x1 | x2)
|
||||
R = [ 10.099 ]
|
||||
S[x2] = [ -0.0990196 ]
|
||||
d = [ -10.098 ]
|
||||
No noise model
|
||||
|
||||
1 1 Leaf p(x1 | x2)
|
||||
1 1 Leaf [1] p(x1 | x2)
|
||||
R = [ 10.099 ]
|
||||
S[x2] = [ -0.0990196 ]
|
||||
d = [ -10 ]
|
||||
|
@ -615,14 +615,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
|
|||
Discrete Keys = (m0, 2), (m1, 2),
|
||||
Choice(m1)
|
||||
0 Choice(m0)
|
||||
0 0 Leaf p(x2)
|
||||
0 0 Leaf [1] p(x2)
|
||||
R = [ 10.0494 ]
|
||||
d = [ -10.1489 ]
|
||||
mean: 1 elements
|
||||
x2: -1.0099
|
||||
No noise model
|
||||
|
||||
0 1 Leaf p(x2)
|
||||
0 1 Leaf [1] p(x2)
|
||||
R = [ 10.0494 ]
|
||||
d = [ -10.1479 ]
|
||||
mean: 1 elements
|
||||
|
@ -630,14 +630,14 @@ conditional 2: Hybrid P( x2 | m0 m1)
|
|||
No noise model
|
||||
|
||||
1 Choice(m0)
|
||||
1 0 Leaf p(x2)
|
||||
1 0 Leaf [1] p(x2)
|
||||
R = [ 10.0494 ]
|
||||
d = [ -10.0504 ]
|
||||
mean: 1 elements
|
||||
x2: -1.0001
|
||||
No noise model
|
||||
|
||||
1 1 Leaf p(x2)
|
||||
1 1 Leaf [1] p(x2)
|
||||
R = [ 10.0494 ]
|
||||
d = [ -10.0494 ]
|
||||
mean: 1 elements
|
||||
|
|
|
@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) {
|
|||
R"(Hybrid [x1 x2; 1]
|
||||
MixtureFactor
|
||||
Choice(1)
|
||||
0 Leaf Nonlinear factor on 2 keys
|
||||
1 Leaf Nonlinear factor on 2 keys
|
||||
0 Leaf [1] Nonlinear factor on 2 keys
|
||||
1 Leaf [1] Nonlinear factor on 2 keys
|
||||
)";
|
||||
EXPECT(assert_print_equal(expected, mixtureFactor));
|
||||
}
|
||||
|
|
|
@ -99,7 +99,7 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************ */
|
||||
void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const {
|
||||
cout << s << " p(";
|
||||
cout << (s.empty() ? "" : s + " ") << "p(";
|
||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||
cout << formatter(*it) << (nrFrontals() > 1 ? " " : "");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue