Merge pull request #1293 from borglab/hybrid/improved-prune
Improved BayesTree pruningrelease/4.3a0
commit
4f6e4e7242
|
|
@ -11,15 +11,17 @@
|
|||
|
||||
/**
|
||||
* @file Assignment.h
|
||||
* @brief An assignment from labels to a discrete value index (size_t)
|
||||
* @brief An assignment from labels to a discrete value index (size_t)
|
||||
* @author Frank Dellaert
|
||||
* @date Feb 5, 2012
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -32,13 +34,30 @@ namespace gtsam {
|
|||
*/
|
||||
template <class L>
|
||||
class Assignment : public std::map<L, size_t> {
|
||||
/**
|
||||
* @brief Default method used by `labelFormatter` or `valueFormatter` when
|
||||
* printing.
|
||||
*
|
||||
* @param x The value passed to format.
|
||||
* @return std::string
|
||||
*/
|
||||
static std::string DefaultFormatter(const L& x) {
|
||||
std::stringstream ss;
|
||||
ss << x;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
public:
|
||||
using std::map<L, size_t>::operator=;
|
||||
|
||||
void print(const std::string& s = "Assignment: ") const {
|
||||
void print(const std::string& s = "Assignment: ",
|
||||
const std::function<std::string(L)>& labelFormatter =
|
||||
&DefaultFormatter) const {
|
||||
std::cout << s << ": ";
|
||||
for (const typename Assignment::value_type& keyValue : *this)
|
||||
std::cout << "(" << keyValue.first << ", " << keyValue.second << ")";
|
||||
for (const typename Assignment::value_type& keyValue : *this) {
|
||||
std::cout << "(" << labelFormatter(keyValue.first) << ", "
|
||||
<< keyValue.second << ")";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -119,11 +119,12 @@ void GaussianMixture::print(const std::string &s,
|
|||
"", [&](Key k) { return formatter(k); },
|
||||
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
|
||||
RedirectCout rd;
|
||||
if (gf && !gf->empty())
|
||||
if (gf && !gf->empty()) {
|
||||
gf->print("", formatter);
|
||||
else
|
||||
return {"nullptr"};
|
||||
return rd.str();
|
||||
return rd.str();
|
||||
} else {
|
||||
return "nullptr";
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -31,8 +31,32 @@ static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
|
||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||
AlgebraicDecisionTree<Key> decisionTree;
|
||||
|
||||
// The canonical decision tree factor which will get the discrete conditionals
|
||||
// added to it.
|
||||
DecisionTreeFactor dtFactor;
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
// Convert to a DecisionTreeFactor and add it to the main factor.
|
||||
DecisionTreeFactor f(*conditional->asDiscreteConditional());
|
||||
dtFactor = dtFactor * f;
|
||||
}
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(dtFactor);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||
// Get the decision tree of only the discrete keys
|
||||
auto discreteConditionals = this->discreteConditionals();
|
||||
const DecisionTreeFactor::shared_ptr discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
|
||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
|
|
|
|||
|
|
@ -73,6 +73,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
|
||||
}
|
||||
|
||||
using Base::push_back;
|
||||
|
||||
/// Get a specific Gaussian mixture by index `i`.
|
||||
GaussianMixture::shared_ptr atMixture(size_t i) const;
|
||||
|
||||
|
|
@ -109,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||
|
||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
||||
HybridBayesNet prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||
protected:
|
||||
/**
|
||||
* @brief Get all the discrete conditionals as a decision tree factor.
|
||||
*
|
||||
* @return DecisionTreeFactor::shared_ptr
|
||||
*/
|
||||
DecisionTreeFactor::shared_ptr discreteConditionals() const;
|
||||
|
||||
public:
|
||||
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
HybridBayesNet prune(size_t maxNrLeaves) const;
|
||||
|
||||
/// @}
|
||||
|
||||
|
|
|
|||
|
|
@ -89,12 +89,12 @@ struct HybridAssignmentData {
|
|||
gaussianbayesTree_(gbt) {}
|
||||
|
||||
/**
|
||||
* @brief A function used during tree traversal that operators on each node
|
||||
* @brief A function used during tree traversal that operates on each node
|
||||
* before visiting the node's children.
|
||||
*
|
||||
* @param node The current node being visited.
|
||||
* @param parentData The HybridAssignmentData from the parent node.
|
||||
* @return HybridAssignmentData
|
||||
* @return HybridAssignmentData which is passed to the children.
|
||||
*/
|
||||
static HybridAssignmentData AssignmentPreOrderVisitor(
|
||||
const HybridBayesTree::sharedNode& node,
|
||||
|
|
@ -144,4 +144,61 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
this->roots_.at(0)->conditional()->inner());
|
||||
|
||||
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDiscreteFactor.root_;
|
||||
|
||||
/// Helper struct for pruning the hybrid bayes tree.
|
||||
struct HybridPrunerData {
|
||||
/// The discrete decision tree after pruning.
|
||||
DecisionTreeFactor prunedDiscreteFactor;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
|
||||
const HybridBayesTree::sharedNode& parentClique)
|
||||
: prunedDiscreteFactor(prunedDiscreteFactor) {}
|
||||
|
||||
/**
|
||||
* @brief A function used during tree traversal that operates on each node
|
||||
* before visiting the node's children.
|
||||
*
|
||||
* @param node The current node being visited.
|
||||
* @param parentData The data from the parent node.
|
||||
* @return HybridPrunerData which is passed to the children.
|
||||
*/
|
||||
static HybridPrunerData AssignmentPreOrderVisitor(
|
||||
const HybridBayesTree::sharedNode& clique,
|
||||
HybridPrunerData& parentData) {
|
||||
// Get the conditional
|
||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||
|
||||
// If conditional is hybrid, we prune it.
|
||||
if (conditional->isHybrid()) {
|
||||
auto gaussianMixture = conditional->asMixture();
|
||||
|
||||
// Check if the number of discrete keys match,
|
||||
// else we get an assignment error.
|
||||
// TODO(Varun) Update prune method to handle assignment subset?
|
||||
if (gaussianMixture->discreteKeys() ==
|
||||
parentData.prunedDiscreteFactor.discreteKeys()) {
|
||||
gaussianMixture->prune(parentData.prunedDiscreteFactor);
|
||||
}
|
||||
}
|
||||
return parentData;
|
||||
}
|
||||
};
|
||||
|
||||
HybridPrunerData rootData(prunedDiscreteFactor, 0);
|
||||
{
|
||||
treeTraversal::no_op visitorPost;
|
||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
TbbOpenMPMixedScope threadLimiter;
|
||||
treeTraversal::DepthFirstForestParallel(
|
||||
*this, rootData, HybridPrunerData::AssignmentPreOrderVisitor,
|
||||
visitorPost);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -88,6 +88,13 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
|||
*/
|
||||
VectorValues optimize(const DiscreteValues& assignment) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the underlying Bayes tree.
|
||||
*
|
||||
* @param maxNumberLeaves The max number of leaf nodes to keep.
|
||||
*/
|
||||
void prune(const size_t maxNumberLeaves);
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -34,8 +34,6 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridGaussianFactorGraph;
|
||||
|
||||
/**
|
||||
* Hybrid Conditional Density
|
||||
*
|
||||
|
|
|
|||
|
|
@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
|
|||
push_hybrid(p);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all the discrete keys in the factor graph.
|
||||
const KeySet discreteKeys() const {
|
||||
KeySet discrete_keys;
|
||||
for (auto& factor : factors_) {
|
||||
for (const DiscreteKey& k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
}
|
||||
}
|
||||
return discrete_keys;
|
||||
}
|
||||
|
||||
/// Get all the continuous keys in the factor graph.
|
||||
const KeySet continuousKeys() const {
|
||||
KeySet keys;
|
||||
for (auto& factor : factors_) {
|
||||
for (const Key& key : factor->continuousKeys()) {
|
||||
keys.insert(key);
|
||||
}
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -96,8 +96,12 @@ GaussianMixtureFactor::Sum sumFrontals(
|
|||
}
|
||||
|
||||
} else if (f->isContinuous()) {
|
||||
deferredFactors.push_back(
|
||||
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
|
||||
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
deferredFactors.push_back(gf->inner());
|
||||
}
|
||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
deferredFactors.push_back(cg->asGaussian());
|
||||
}
|
||||
|
||||
} else if (f->isDiscrete()) {
|
||||
// Don't do anything for discrete-only factors
|
||||
|
|
@ -404,31 +408,9 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
|||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const {
|
||||
KeySet discrete_keys;
|
||||
for (auto &factor : factors_) {
|
||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
}
|
||||
}
|
||||
return discrete_keys;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
const KeySet HybridGaussianFactorGraph::getContinuousKeys() const {
|
||||
KeySet keys;
|
||||
for (auto &factor : factors_) {
|
||||
for (const Key &key : factor->continuousKeys()) {
|
||||
keys.insert(key);
|
||||
}
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
||||
KeySet discrete_keys = getDiscreteKeys();
|
||||
KeySet discrete_keys = discreteKeys();
|
||||
for (auto &factor : factors_) {
|
||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
|
|
|
|||
|
|
@ -161,12 +161,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
}
|
||||
}
|
||||
|
||||
/// Get all the discrete keys in the factor graph.
|
||||
const KeySet getDiscreteKeys() const;
|
||||
|
||||
/// Get all the continuous keys in the factor graph.
|
||||
const KeySet getContinuousKeys() const;
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
* eliminated after the continuous keys.
|
||||
|
|
|
|||
|
|
@ -14,9 +14,10 @@
|
|||
* @date March 31, 2022
|
||||
* @author Fan Jiang
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <gtsam/base/treeTraversal-inst.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||
|
|
@ -41,6 +42,7 @@ HybridGaussianISAM::HybridGaussianISAM(const HybridBayesTree& bayesTree)
|
|||
void HybridGaussianISAM::updateInternal(
|
||||
const HybridGaussianFactorGraph& newFactors,
|
||||
HybridBayesTree::Cliques* orphans,
|
||||
const boost::optional<size_t>& maxNrLeaves,
|
||||
const boost::optional<Ordering>& ordering,
|
||||
const HybridBayesTree::Eliminate& function) {
|
||||
// Remove the contaminated part of the Bayes tree
|
||||
|
|
@ -60,23 +62,24 @@ void HybridGaussianISAM::updateInternal(
|
|||
for (const sharedClique& orphan : *orphans)
|
||||
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan);
|
||||
|
||||
KeySet allDiscrete;
|
||||
for (auto& factor : factors) {
|
||||
for (auto& k : factor->discreteKeys()) {
|
||||
allDiscrete.insert(k.first);
|
||||
}
|
||||
}
|
||||
// Get all the discrete keys from the factors
|
||||
KeySet allDiscrete = factors.discreteKeys();
|
||||
|
||||
// Create KeyVector with continuous keys followed by discrete keys.
|
||||
KeyVector newKeysDiscreteLast;
|
||||
// Insert continuous keys first.
|
||||
for (auto& k : newFactorKeys) {
|
||||
if (!allDiscrete.exists(k)) {
|
||||
newKeysDiscreteLast.push_back(k);
|
||||
}
|
||||
}
|
||||
// Insert discrete keys at the end
|
||||
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
||||
std::back_inserter(newKeysDiscreteLast));
|
||||
|
||||
// Get an ordering where the new keys are eliminated last
|
||||
const VariableIndex index(factors);
|
||||
|
||||
Ordering elimination_ordering;
|
||||
if (ordering) {
|
||||
elimination_ordering = *ordering;
|
||||
|
|
@ -91,6 +94,10 @@ void HybridGaussianISAM::updateInternal(
|
|||
HybridBayesTree::shared_ptr bayesTree =
|
||||
factors.eliminateMultifrontal(elimination_ordering, function, index);
|
||||
|
||||
if (maxNrLeaves) {
|
||||
bayesTree->prune(*maxNrLeaves);
|
||||
}
|
||||
|
||||
// Re-add into Bayes tree data structures
|
||||
this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(),
|
||||
bayesTree->roots().end());
|
||||
|
|
@ -99,61 +106,11 @@ void HybridGaussianISAM::updateInternal(
|
|||
|
||||
/* ************************************************************************* */
|
||||
void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
|
||||
const boost::optional<size_t>& maxNrLeaves,
|
||||
const boost::optional<Ordering>& ordering,
|
||||
const HybridBayesTree::Eliminate& function) {
|
||||
Cliques orphans;
|
||||
this->updateInternal(newFactors, &orphans, ordering, function);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* @brief Check if `b` is a subset of `a`.
|
||||
* Non-const since they need to be sorted.
|
||||
*
|
||||
* @param a KeyVector
|
||||
* @param b KeyVector
|
||||
* @return True if the keys of b is a subset of a, else false.
|
||||
*/
|
||||
bool IsSubset(KeyVector a, KeyVector b) {
|
||||
std::sort(a.begin(), a.end());
|
||||
std::sort(b.begin(), b.end());
|
||||
return std::includes(a.begin(), a.end(), b.begin(), b.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
|
||||
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
this->clique(root)->conditional()->inner());
|
||||
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDiscreteFactor.root_;
|
||||
|
||||
std::vector<gtsam::Key> prunedKeys;
|
||||
for (auto&& clique : nodes()) {
|
||||
// The cliques can be repeated for each frontal so we record it in
|
||||
// prunedKeys and check if we have already pruned a particular clique.
|
||||
if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) !=
|
||||
prunedKeys.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add all the keys of the current clique to be pruned to prunedKeys
|
||||
for (auto&& key : clique.second->conditional()->frontals()) {
|
||||
prunedKeys.push_back(key);
|
||||
}
|
||||
|
||||
// Convert parents() to a KeyVector for comparison
|
||||
KeyVector parents;
|
||||
for (auto&& parent : clique.second->conditional()->parents()) {
|
||||
parents.push_back(parent);
|
||||
}
|
||||
|
||||
if (IsSubset(parents, decisionTree->keys())) {
|
||||
auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
clique.second->conditional()->inner());
|
||||
|
||||
gaussianMixture->prune(prunedDiscreteFactor);
|
||||
}
|
||||
}
|
||||
this->updateInternal(newFactors, &orphans, maxNrLeaves, ordering, function);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
|
|||
void updateInternal(
|
||||
const HybridGaussianFactorGraph& newFactors,
|
||||
HybridBayesTree::Cliques* orphans,
|
||||
const boost::optional<size_t>& maxNrLeaves = boost::none,
|
||||
const boost::optional<Ordering>& ordering = boost::none,
|
||||
const HybridBayesTree::Eliminate& function =
|
||||
HybridBayesTree::EliminationTraitsType::DefaultEliminate);
|
||||
|
|
@ -57,20 +58,15 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
|
|||
* @brief Perform update step with new factors.
|
||||
*
|
||||
* @param newFactors Factor graph of new factors to add and eliminate.
|
||||
* @param maxNrLeaves The maximum number of leaves to keep after pruning.
|
||||
* @param ordering Custom elimination ordering.
|
||||
* @param function Elimination function.
|
||||
*/
|
||||
void update(const HybridGaussianFactorGraph& newFactors,
|
||||
const boost::optional<size_t>& maxNrLeaves = boost::none,
|
||||
const boost::optional<Ordering>& ordering = boost::none,
|
||||
const HybridBayesTree::Eliminate& function =
|
||||
HybridBayesTree::EliminationTraitsType::DefaultEliminate);
|
||||
|
||||
/**
|
||||
* @brief Prune the underlying Bayes tree.
|
||||
*
|
||||
* @param root The root key in the discrete conditional decision tree.
|
||||
* @param maxNumberLeaves
|
||||
*/
|
||||
void prune(const Key& root, const size_t maxNumberLeaves);
|
||||
};
|
||||
|
||||
/// traits
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ void HybridNonlinearISAM::saveGraph(const string& s,
|
|||
|
||||
/* ************************************************************************* */
|
||||
void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors,
|
||||
const Values& initialValues) {
|
||||
const Values& initialValues,
|
||||
const boost::optional<size_t>& maxNrLeaves,
|
||||
const boost::optional<Ordering>& ordering) {
|
||||
if (newFactors.size() > 0) {
|
||||
// Reorder and relinearize every reorderInterval updates
|
||||
if (reorderInterval_ > 0 && ++reorderCounter_ >= reorderInterval_) {
|
||||
|
|
@ -51,7 +53,8 @@ void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors,
|
|||
newFactors.linearize(linPoint_);
|
||||
|
||||
// Update ISAM
|
||||
isam_.update(*linearizedNewFactors, boost::none, eliminationFunction_);
|
||||
isam_.update(*linearizedNewFactors, maxNrLeaves, ordering,
|
||||
eliminationFunction_);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -66,7 +69,7 @@ void HybridNonlinearISAM::reorder_relinearize() {
|
|||
// Just recreate the whole BayesTree
|
||||
// TODO: allow for constrained ordering here
|
||||
// TODO: decouple relinearization and reordering to avoid
|
||||
isam_.update(*factors_.linearize(newLinPoint), boost::none,
|
||||
isam_.update(*factors_.linearize(newLinPoint), boost::none, boost::none,
|
||||
eliminationFunction_);
|
||||
|
||||
// Update linearization point
|
||||
|
|
|
|||
|
|
@ -82,12 +82,9 @@ class GTSAM_EXPORT HybridNonlinearISAM {
|
|||
/**
|
||||
* @brief Prune the underlying Bayes tree.
|
||||
*
|
||||
* @param root The root key in the discrete conditional decision tree.
|
||||
* @param maxNumberLeaves
|
||||
* @param maxNumberLeaves The max number of leaf nodes to keep.
|
||||
*/
|
||||
void prune(const Key& root, const size_t maxNumberLeaves) {
|
||||
isam_.prune(root, maxNumberLeaves);
|
||||
}
|
||||
void prune(const size_t maxNumberLeaves) { isam_.prune(maxNumberLeaves); }
|
||||
|
||||
/** Return the current linearization point */
|
||||
const Values& getLinearizationPoint() const { return linPoint_; }
|
||||
|
|
@ -121,7 +118,9 @@ class GTSAM_EXPORT HybridNonlinearISAM {
|
|||
|
||||
/** Add new factors along with their initial linearization points */
|
||||
void update(const HybridNonlinearFactorGraph& newFactors,
|
||||
const Values& initialValues);
|
||||
const Values& initialValues,
|
||||
const boost::optional<size_t>& maxNrLeaves = boost::none,
|
||||
const boost::optional<Ordering>& ordering = boost::none);
|
||||
|
||||
/** Relinearization and reordering of variables */
|
||||
void reorder_relinearize();
|
||||
|
|
|
|||
|
|
@ -115,7 +115,6 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
|
|||
/* ***************************************************************************
|
||||
*/
|
||||
using MotionModel = BetweenFactor<double>;
|
||||
// using MotionMixture = MixtureFactor<MotionModel>;
|
||||
|
||||
// Test fixture with switching network.
|
||||
struct Switching {
|
||||
|
|
@ -125,7 +124,13 @@ struct Switching {
|
|||
HybridGaussianFactorGraph linearizedFactorGraph;
|
||||
Values linearizationPoint;
|
||||
|
||||
/// Create with given number of time steps.
|
||||
/**
|
||||
* @brief Create with given number of time steps.
|
||||
*
|
||||
* @param K The total number of timesteps.
|
||||
* @param between_sigma The stddev between poses.
|
||||
* @param prior_sigma The stddev on priors (also used for measurements).
|
||||
*/
|
||||
Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1)
|
||||
: K(K) {
|
||||
// Create DiscreteKeys for binary K modes, modes[0] will not be used.
|
||||
|
|
@ -166,6 +171,8 @@ struct Switching {
|
|||
linearizationPoint.insert<double>(X(k), static_cast<double>(k));
|
||||
}
|
||||
|
||||
// The ground truth is robot moving forward
|
||||
// and one less than the linearization point
|
||||
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,20 @@ TEST(HybridBayesNet, Creation) {
|
|||
EXPECT(df.equals(expected));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test adding a bayes net to another one.
|
||||
TEST(HybridBayesNet, Add) {
|
||||
HybridBayesNet bayesNet;
|
||||
|
||||
bayesNet.add(Asia, "99/1");
|
||||
|
||||
DiscreteConditional expected(Asia, "99/1");
|
||||
|
||||
HybridBayesNet other;
|
||||
other.push_back(bayesNet);
|
||||
EXPECT(bayesNet.equals(other));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test choosing an assignment of conditionals
|
||||
TEST(HybridBayesNet, Choose) {
|
||||
|
|
@ -169,6 +183,24 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
|
|||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net pruning
|
||||
TEST(HybridBayesNet, Prune) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
auto prunedBayesNet = hybridBayesNet->prune(2);
|
||||
HybridValues pruned_delta = prunedBayesNet.optimize();
|
||||
|
||||
EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
|
||||
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridBayesNet, Serialization) {
|
||||
|
|
|
|||
|
|
@ -500,6 +500,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(HybridGaussianFactorGraph, optimize) {
|
||||
HybridGaussianFactorGraph hfg;
|
||||
|
||||
|
|
@ -521,6 +522,46 @@ TEST(HybridGaussianFactorGraph, optimize) {
|
|||
|
||||
EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test adding of gaussian conditional and re-elimination.
|
||||
TEST(HybridGaussianFactorGraph, Conditionals) {
|
||||
Switching switching(4);
|
||||
HybridGaussianFactorGraph hfg;
|
||||
|
||||
hfg.push_back(switching.linearizedFactorGraph.at(0)); // P(X1)
|
||||
Ordering ordering;
|
||||
ordering.push_back(X(1));
|
||||
HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering);
|
||||
|
||||
hfg.push_back(switching.linearizedFactorGraph.at(1)); // P(X1, X2 | M1)
|
||||
hfg.push_back(*bayes_net);
|
||||
hfg.push_back(switching.linearizedFactorGraph.at(2)); // P(X2, X3 | M2)
|
||||
hfg.push_back(switching.linearizedFactorGraph.at(5)); // P(M1)
|
||||
ordering.push_back(X(2));
|
||||
ordering.push_back(X(3));
|
||||
ordering.push_back(M(1));
|
||||
ordering.push_back(M(2));
|
||||
|
||||
bayes_net = hfg.eliminateSequential(ordering);
|
||||
|
||||
HybridValues result = bayes_net->optimize();
|
||||
|
||||
Values expected_continuous;
|
||||
expected_continuous.insert<double>(X(1), 0);
|
||||
expected_continuous.insert<double>(X(2), 1);
|
||||
expected_continuous.insert<double>(X(3), 2);
|
||||
expected_continuous.insert<double>(X(4), 4);
|
||||
Values result_continuous =
|
||||
switching.linearizationPoint.retract(result.continuous());
|
||||
EXPECT(assert_equal(expected_continuous, result_continuous));
|
||||
|
||||
DiscreteValues expected_discrete;
|
||||
expected_discrete[M(1)] = 1;
|
||||
expected_discrete[M(2)] = 1;
|
||||
EXPECT(assert_equal(expected_discrete, result.discrete()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
@ -235,7 +235,7 @@ TEST(HybridGaussianElimination, Approx_inference) {
|
|||
size_t maxNrLeaves = 5;
|
||||
incrementalHybrid.update(graph1);
|
||||
|
||||
incrementalHybrid.prune(M(3), maxNrLeaves);
|
||||
incrementalHybrid.prune(maxNrLeaves);
|
||||
|
||||
/*
|
||||
unpruned factor is:
|
||||
|
|
@ -329,7 +329,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
|
|||
// Run update with pruning
|
||||
size_t maxComponents = 5;
|
||||
incrementalHybrid.update(graph1);
|
||||
incrementalHybrid.prune(M(3), maxComponents);
|
||||
incrementalHybrid.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with 4 hybrid nodes,
|
||||
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
|
||||
|
|
@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
|
|||
|
||||
// Run update with pruning a second time.
|
||||
incrementalHybrid.update(graph2);
|
||||
incrementalHybrid.prune(M(4), maxComponents);
|
||||
incrementalHybrid.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with pruned hybrid nodes,
|
||||
// with 5 (pruned) leaves.
|
||||
|
|
@ -496,7 +496,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
// The MHS at this point should be a 2 level tree on (1, 2).
|
||||
// 1 has 2 choices, and 2 has 4 choices.
|
||||
inc.update(gfg);
|
||||
inc.prune(M(2), 2);
|
||||
inc.prune(2);
|
||||
|
||||
/*************** Run Round 4 ***************/
|
||||
// Add odometry factor with discrete modes.
|
||||
|
|
@ -531,7 +531,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
|
||||
// Keep pruning!
|
||||
inc.update(gfg);
|
||||
inc.prune(M(3), 3);
|
||||
inc.prune(3);
|
||||
|
||||
// The final discrete graph should not be empty since we have eliminated
|
||||
// all continuous variables.
|
||||
|
|
@ -256,7 +256,7 @@ TEST(HybridNonlinearISAM, Approx_inference) {
|
|||
incrementalHybrid.update(graph1, initial);
|
||||
HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
|
||||
|
||||
bayesTree.prune(M(3), maxNrLeaves);
|
||||
bayesTree.prune(maxNrLeaves);
|
||||
|
||||
/*
|
||||
unpruned factor is:
|
||||
|
|
@ -355,7 +355,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
|
|||
incrementalHybrid.update(graph1, initial);
|
||||
HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
|
||||
|
||||
bayesTree.prune(M(3), maxComponents);
|
||||
bayesTree.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with 4 hybrid nodes,
|
||||
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
|
||||
|
|
@ -380,7 +380,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
|
|||
incrementalHybrid.update(graph2, initial);
|
||||
bayesTree = incrementalHybrid.bayesTree();
|
||||
|
||||
bayesTree.prune(M(4), maxComponents);
|
||||
bayesTree.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with pruned hybrid nodes,
|
||||
// with 5 (pruned) leaves.
|
||||
|
|
@ -482,8 +482,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
|||
still = boost::make_shared<PlanarMotionModel>(W(1), W(2), Pose2(0, 0, 0),
|
||||
noise_model);
|
||||
moving =
|
||||
boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry,
|
||||
noise_model);
|
||||
boost::make_shared<PlanarMotionModel>(W(1), W(2), odometry, noise_model);
|
||||
components = {moving, still};
|
||||
mixtureFactor = boost::make_shared<MixtureFactor>(
|
||||
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(2), 2)}, components);
|
||||
|
|
@ -515,7 +514,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
|||
// The MHS at this point should be a 2 level tree on (1, 2).
|
||||
// 1 has 2 choices, and 2 has 4 choices.
|
||||
inc.update(fg, initial);
|
||||
inc.prune(M(2), 2);
|
||||
inc.prune(2);
|
||||
|
||||
fg = HybridNonlinearFactorGraph();
|
||||
initial = Values();
|
||||
|
|
@ -526,8 +525,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
|||
still = boost::make_shared<PlanarMotionModel>(W(2), W(3), Pose2(0, 0, 0),
|
||||
noise_model);
|
||||
moving =
|
||||
boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry,
|
||||
noise_model);
|
||||
boost::make_shared<PlanarMotionModel>(W(2), W(3), odometry, noise_model);
|
||||
components = {moving, still};
|
||||
mixtureFactor = boost::make_shared<MixtureFactor>(
|
||||
contKeys, DiscreteKeys{gtsam::DiscreteKey(M(3), 2)}, components);
|
||||
|
|
@ -551,7 +549,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
|||
|
||||
// Keep pruning!
|
||||
inc.update(fg, initial);
|
||||
inc.prune(M(3), 3);
|
||||
inc.prune(3);
|
||||
|
||||
fg = HybridNonlinearFactorGraph();
|
||||
initial = Values();
|
||||
|
|
@ -560,8 +558,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
|||
|
||||
// The final discrete graph should not be empty since we have eliminated
|
||||
// all continuous variables.
|
||||
auto discreteTree =
|
||||
bayesTree[M(3)]->conditional()->asDiscreteConditional();
|
||||
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional();
|
||||
EXPECT_LONGS_EQUAL(3, discreteTree->size());
|
||||
|
||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ namespace gtsam {
|
|||
// Forward declarations
|
||||
template<class FACTOR> class FactorGraph;
|
||||
template<class BAYESTREE, class GRAPH> class EliminatableClusterTree;
|
||||
class HybridBayesTreeClique;
|
||||
|
||||
/* ************************************************************************* */
|
||||
/** clique statistics */
|
||||
|
|
|
|||
Loading…
Reference in New Issue