Merge pull request #1306 from borglab/varun/test-hybrid-estimation

Hybrid Updates
release/4.3a0
Varun Agrawal 2022-10-14 09:02:33 -04:00 committed by GitHub
commit cd69c51e86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 113 additions and 100 deletions

View File

@@ -159,6 +159,10 @@ TEST(DiscreteBayesTree, ThinTree) {
clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
// check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 =

View File

@@ -128,25 +128,81 @@ void GaussianMixture::print(const std::string &s,
});
}
/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&decisionTree](
/* ************************************************************************* */
/// Collect the entries of a DiscreteKeys vector into a std::set
/// (duplicates collapse; ordering follows DiscreteKey's operator<).
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
  return std::set<DiscreteKey>(dkeys.begin(), dkeys.end());
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
if (gaussianMixtureKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
} else {
return conditional;
std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
gaussianMixtureKeySet.begin(),
gaussianMixtureKeySet.end(),
std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end());
// If any one of the sub-branches are non-zero,
// we need this conditional.
if (decisionTree(augmented_values) > 0.0) {
return conditional;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return nullptr;
}
};
return pruner;
}
/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = prunerFunc(decisionTree);
auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_;

View File

@@ -69,6 +69,17 @@ class GTSAM_EXPORT GaussianMixture
*/
Sum asGaussianFactorGraphTree() const;
/**
* @brief Helper function to get the pruner functor.
*
* @param decisionTree The pruned discrete probability decision tree.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &decisionTree);
public:
/// @name Constructors
/// @{

View File

@@ -57,11 +57,12 @@ void GaussianMixtureFactor::print(const std::string &s,
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf)
if (gf && !gf->empty()) {
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
return rd.str();
} else {
return "nullptr";
}
});
std::cout << "}" << std::endl;
}

View File

@@ -22,14 +22,6 @@
namespace gtsam {
/* ************************************************************************* */
/// Collect the entries of a DiscreteKeys vector into a std::set
/// (duplicates collapse; ordering follows DiscreteKey's operator<).
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
  return std::set<DiscreteKey>(dkeys.begin(), dkeys.end());
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;
@@ -66,61 +58,18 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
HybridBayesNet prunedBayesNetFragment;
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
if ((*discreteFactor)(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};
// Go through all the conditionals in the
// Bayes Net and prune them as per discreteFactor.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
GaussianMixture::shared_ptr gaussianMixture =
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());
if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
if (gaussianMixture) {
// We may have mixtures with less discrete keys than discreteFactor so we
// skip those since the label assignment does not exist.
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
if (gmKeySet != dfKeySet) {
// Add the gaussianMixture which doesn't have to be pruned.
prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(gaussianMixture));
continue;
}
// Run the pruning to get a new, pruned tree
GaussianMixture::Conditionals prunedTree =
gaussianMixture->conditionals().apply(pruner);
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
// reverse keys to get a natural ordering
std::reverse(discreteKeys.begin(), discreteKeys.end());
// Convert from boost::iterator_range to KeyVector
// so we can pass it to constructor.
KeyVector frontals(gaussianMixture->frontals().begin(),
gaussianMixture->frontals().end()),
parents(gaussianMixture->parents().begin(),
gaussianMixture->parents().end());
// Create the new gaussian mixture and add it to the bayes net.
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
frontals, parents, discreteKeys, prunedTree);
// Make a copy of the gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*discreteFactor);
// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
@@ -173,7 +122,7 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn;
}
/* *******************************************************************************/
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE
DiscreteBayesNet discrete_bn;
@@ -190,7 +139,7 @@ HybridValues HybridBayesNet::optimize() const {
return HybridValues(mpe, gbn.optimize());
}
/* *******************************************************************************/
/* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);
return gbn.optimize();

View File

@@ -149,16 +149,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner());
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteFactor;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
DecisionTreeFactor prunedDecisionTree;
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteFactor(prunedDiscreteFactor) {}
: prunedDecisionTree(prunedDecisionTree) {}
/**
* @brief A function used during tree traversal that operates on each node
@@ -178,19 +178,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();
// Check if the number of discrete keys match,
// else we get an assignment error.
// TODO(Varun) Update prune method to handle assignment subset?
if (gaussianMixture->discreteKeys() ==
parentData.prunedDiscreteFactor.discreteKeys()) {
gaussianMixture->prune(parentData.prunedDiscreteFactor);
}
gaussianMixture->prune(parentData.prunedDecisionTree);
}
return parentData;
}
};
HybridPrunerData rootData(prunedDiscreteFactor, 0);
HybridPrunerData rootData(prunedDecisionTree, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP

View File

@@ -50,9 +50,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
/* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &keys)
: Base(keys),
isContinuous_(true),
continuousKeys_(keys) {}
: Base(keys), isContinuous_(true), continuousKeys_(keys) {}
/* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &continuousKeys,
@@ -101,7 +99,6 @@ void HybridFactor::print(const std::string &s,
if (d < discreteKeys_.size() - 1) {
std::cout << " ";
}
}
std::cout << "]";
}

View File

@@ -219,10 +219,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
result = EliminatePreferCholesky(graph, frontalKeys);
if (keysOfEliminated.empty()) {
keysOfEliminated =
result.first->keys(); // Initialize the keysOfEliminated to be the
// Initialize the keysOfEliminated to be the keys of the
// eliminated GaussianConditional
keysOfEliminated = result.first->keys();
}
// keysOfEliminated of the GaussianConditional
if (keysOfSeparator.empty()) {
keysOfSeparator = result.second->keys();
}

View File

@@ -59,8 +59,9 @@ void HybridGaussianISAM::updateInternal(
factors += newFactors;
// Add the orphaned subtrees
for (const sharedClique& orphan : *orphans)
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan);
for (const sharedClique& orphan : *orphans) {
factors += boost::make_shared<BayesTreeOrphanWrapper<Node>>(orphan);
}
// Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeys();

View File

@@ -99,7 +99,7 @@ void HybridNonlinearISAM::print(const string& s,
const KeyFormatter& keyFormatter) const {
cout << s << "ReorderInterval: " << reorderInterval_
<< " Current Count: " << reorderCounter_ << endl;
isam_.print("HybridGaussianISAM:\n");
isam_.print("HybridGaussianISAM:\n", keyFormatter);
linPoint_.print("Linearization Point:\n", keyFormatter);
factors_.print("Nonlinear Graph:\n", keyFormatter);
}

View File

@@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
EXPECT_LONGS_EQUAL(
2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(

View File

@@ -12,7 +12,7 @@
/**
* @file testHybridNonlinearISAM.cpp
* @brief Unit tests for nonlinear incremental inference
* @author Fan Jiang, Varun Agrawal, Frank Dellaert
* @author Varun Agrawal, Fan Jiang, Frank Dellaert
* @date Jan 2021
*/
@@ -363,7 +363,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
EXPECT_LONGS_EQUAL(
2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
4, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
@@ -432,9 +432,9 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// Don't run update now since we don't have discrete variables involved.
/*************** Run Round 2 ***************/
using PlanarMotionModel = BetweenFactor<Pose2>;
/*************** Run Round 2 ***************/
// Add odometry factor with discrete modes.
Pose2 odometry(1.0, 0.0, 0.0);
KeyVector contKeys = {W(0), W(1)};