Merge pull request #1306 from borglab/varun/test-hybrid-estimation

Hybrid Updates
release/4.3a0
Varun Agrawal 2022-10-14 09:02:33 -04:00 committed by GitHub
commit cd69c51e86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 113 additions and 100 deletions

View File

@ -159,6 +159,10 @@ TEST(DiscreteBayesTree, ThinTree) {
clique->separatorMarginal(EliminateDiscrete); clique->separatorMarginal(EliminateDiscrete);
DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
DOUBLES_EQUAL(joint_12_14, 0.1875, 1e-9);
DOUBLES_EQUAL(joint_8_12_14, 0.0375, 1e-9);
DOUBLES_EQUAL(joint_9_12_14, 0.15, 1e-9);
// check separator marginal P(S9), should be P(14) // check separator marginal P(S9), should be P(14)
clique = (*self.bayesTree)[9]; clique = (*self.bayesTree)[9];
DiscreteFactorGraph separatorMarginal9 = DiscreteFactorGraph separatorMarginal9 =

View File

@ -128,25 +128,81 @@ void GaussianMixture::print(const std::string &s,
}); });
} }
/* *******************************************************************************/ /* ************************************************************************* */
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { /// Return the DiscreteKey vector as a set.
// Functional which loops over all assignments and create a set of std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
// GaussianConditionals std::set<DiscreteKey> s;
auto pruner = [&decisionTree]( s.insert(dkeys.begin(), dkeys.end());
return s;
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
const Assignment<Key> &choices, const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional) const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr { -> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value // typecast so we can use this to get probability value
DiscreteValues values(choices); DiscreteValues values(choices);
if (decisionTree(values) == 0.0) { // Case where the gaussian mixture has the same
// empty aka null pointer // discrete keys as the decision tree.
boost::shared_ptr<GaussianConditional> null; if (gaussianMixtureKeySet == decisionTreeKeySet) {
return null; if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
} else { } else {
return conditional; std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
gaussianMixtureKeySet.begin(),
gaussianMixtureKeySet.end(),
std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
augmented_values.insert(assignment.begin(), assignment.end());
// If any one of the sub-branches are non-zero,
// we need this conditional.
if (decisionTree(augmented_values) > 0.0) {
return conditional;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return nullptr;
} }
}; };
return pruner;
}
/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = prunerFunc(decisionTree);
auto pruned_conditionals = conditionals_.apply(pruner); auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_; conditionals_.root_ = pruned_conditionals.root_;

View File

@ -69,6 +69,17 @@ class GTSAM_EXPORT GaussianMixture
*/ */
Sum asGaussianFactorGraphTree() const; Sum asGaussianFactorGraphTree() const;
/**
* @brief Helper function to get the pruner functor.
*
* @param decisionTree The pruned discrete probability decision tree.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &decisionTree);
public: public:
/// @name Constructors /// @name Constructors
/// @{ /// @{

View File

@ -57,11 +57,12 @@ void GaussianMixtureFactor::print(const std::string &s,
[&](const GaussianFactor::shared_ptr &gf) -> std::string { [&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd; RedirectCout rd;
std::cout << ":\n"; std::cout << ":\n";
if (gf) if (gf && !gf->empty()) {
gf->print("", formatter); gf->print("", formatter);
else return rd.str();
return {"nullptr"}; } else {
return rd.str(); return "nullptr";
}
}); });
std::cout << "}" << std::endl; std::cout << "}" << std::endl;
} }

View File

@ -22,14 +22,6 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
/// Return the DiscreteKey vector as a set.
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
return s;
}
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree; AlgebraicDecisionTree<Key> decisionTree;
@ -66,61 +58,18 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
HybridBayesNet prunedBayesNetFragment; HybridBayesNet prunedBayesNetFragment;
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
if ((*discreteFactor)(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};
// Go through all the conditionals in the // Go through all the conditionals in the
// Bayes Net and prune them as per discreteFactor. // Bayes Net and prune them as per discreteFactor.
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i); HybridConditional::shared_ptr conditional = this->at(i);
GaussianMixture::shared_ptr gaussianMixture = if (conditional->isHybrid()) {
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner()); GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
if (gaussianMixture) { // Make a copy of the gaussian mixture and prune it!
// We may have mixtures with less discrete keys than discreteFactor so we auto prunedGaussianMixture =
// skip those since the label assignment does not exist. boost::make_shared<GaussianMixture>(*gaussianMixture);
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys()); prunedGaussianMixture->prune(*discreteFactor);
auto dfKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
if (gmKeySet != dfKeySet) {
// Add the gaussianMixture which doesn't have to be pruned.
prunedBayesNetFragment.push_back(
boost::make_shared<HybridConditional>(gaussianMixture));
continue;
}
// Run the pruning to get a new, pruned tree
GaussianMixture::Conditionals prunedTree =
gaussianMixture->conditionals().apply(pruner);
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
// reverse keys to get a natural ordering
std::reverse(discreteKeys.begin(), discreteKeys.end());
// Convert from boost::iterator_range to KeyVector
// so we can pass it to constructor.
KeyVector frontals(gaussianMixture->frontals().begin(),
gaussianMixture->frontals().end()),
parents(gaussianMixture->parents().begin(),
gaussianMixture->parents().end());
// Create the new gaussian mixture and add it to the bayes net.
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
frontals, parents, discreteKeys, prunedTree);
// Type-erase and add to the pruned Bayes Net fragment. // Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back( prunedBayesNetFragment.push_back(
@ -173,7 +122,7 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn; return gbn;
} }
/* *******************************************************************************/ /* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const { HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE // Solve for the MPE
DiscreteBayesNet discrete_bn; DiscreteBayesNet discrete_bn;
@ -190,7 +139,7 @@ HybridValues HybridBayesNet::optimize() const {
return HybridValues(mpe, gbn.optimize()); return HybridValues(mpe, gbn.optimize());
} }
/* *******************************************************************************/ /* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment); GaussianBayesNet gbn = this->choose(assignment);
return gbn.optimize(); return gbn.optimize();

View File

@ -149,16 +149,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>( auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner()); this->roots_.at(0)->conditional()->inner());
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_; decisionTree->root_ = prunedDecisionTree.root_;
/// Helper struct for pruning the hybrid bayes tree. /// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData { struct HybridPrunerData {
/// The discrete decision tree after pruning. /// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteFactor; DecisionTreeFactor prunedDecisionTree;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
const HybridBayesTree::sharedNode& parentClique) const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteFactor(prunedDiscreteFactor) {} : prunedDecisionTree(prunedDecisionTree) {}
/** /**
* @brief A function used during tree traversal that operates on each node * @brief A function used during tree traversal that operates on each node
@ -178,19 +178,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (conditional->isHybrid()) { if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture(); auto gaussianMixture = conditional->asMixture();
// Check if the number of discrete keys match, gaussianMixture->prune(parentData.prunedDecisionTree);
// else we get an assignment error.
// TODO(Varun) Update prune method to handle assignment subset?
if (gaussianMixture->discreteKeys() ==
parentData.prunedDiscreteFactor.discreteKeys()) {
gaussianMixture->prune(parentData.prunedDiscreteFactor);
}
} }
return parentData; return parentData;
} }
}; };
HybridPrunerData rootData(prunedDiscreteFactor, 0); HybridPrunerData rootData(prunedDecisionTree, 0);
{ {
treeTraversal::no_op visitorPost; treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP // Limits OpenMP threads since we're mixing TBB and OpenMP

View File

@ -50,9 +50,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &keys) HybridFactor::HybridFactor(const KeyVector &keys)
: Base(keys), : Base(keys), isContinuous_(true), continuousKeys_(keys) {}
isContinuous_(true),
continuousKeys_(keys) {}
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &continuousKeys, HybridFactor::HybridFactor(const KeyVector &continuousKeys,
@ -101,7 +99,6 @@ void HybridFactor::print(const std::string &s,
if (d < discreteKeys_.size() - 1) { if (d < discreteKeys_.size() - 1) {
std::cout << " "; std::cout << " ";
} }
} }
std::cout << "]"; std::cout << "]";
} }

View File

@ -219,10 +219,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
result = EliminatePreferCholesky(graph, frontalKeys); result = EliminatePreferCholesky(graph, frontalKeys);
if (keysOfEliminated.empty()) { if (keysOfEliminated.empty()) {
keysOfEliminated = // Initialize the keysOfEliminated to be the keys of the
result.first->keys(); // Initialize the keysOfEliminated to be the // eliminated GaussianConditional
keysOfEliminated = result.first->keys();
} }
// keysOfEliminated of the GaussianConditional
if (keysOfSeparator.empty()) { if (keysOfSeparator.empty()) {
keysOfSeparator = result.second->keys(); keysOfSeparator = result.second->keys();
} }

View File

@ -59,8 +59,9 @@ void HybridGaussianISAM::updateInternal(
factors += newFactors; factors += newFactors;
// Add the orphaned subtrees // Add the orphaned subtrees
for (const sharedClique& orphan : *orphans) for (const sharedClique& orphan : *orphans) {
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan); factors += boost::make_shared<BayesTreeOrphanWrapper<Node>>(orphan);
}
// Get all the discrete keys from the factors // Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeys(); KeySet allDiscrete = factors.discreteKeys();

View File

@ -99,7 +99,7 @@ void HybridNonlinearISAM::print(const string& s,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter) const {
cout << s << "ReorderInterval: " << reorderInterval_ cout << s << "ReorderInterval: " << reorderInterval_
<< " Current Count: " << reorderCounter_ << endl; << " Current Count: " << reorderCounter_ << endl;
isam_.print("HybridGaussianISAM:\n"); isam_.print("HybridGaussianISAM:\n", keyFormatter);
linPoint_.print("Linearization Point:\n", keyFormatter); linPoint_.print("Linearization Point:\n", keyFormatter);
factors_.print("Nonlinear Graph:\n", keyFormatter); factors_.print("Nonlinear Graph:\n", keyFormatter);
} }

View File

@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents()); 2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); 3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents()); 5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(

View File

@ -12,7 +12,7 @@
/** /**
* @file testHybridNonlinearISAM.cpp * @file testHybridNonlinearISAM.cpp
* @brief Unit tests for nonlinear incremental inference * @brief Unit tests for nonlinear incremental inference
* @author Fan Jiang, Varun Agrawal, Frank Dellaert * @author Varun Agrawal, Fan Jiang, Frank Dellaert
* @date Jan 2021 * @date Jan 2021
*/ */
@ -363,7 +363,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents()); 2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
4, bayesTree[X(2)]->conditional()->asMixture()->nrComponents()); 3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents()); 5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL( EXPECT_LONGS_EQUAL(
@ -432,9 +432,9 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// Don't run update now since we don't have discrete variables involved. // Don't run update now since we don't have discrete variables involved.
/*************** Run Round 2 ***************/
using PlanarMotionModel = BetweenFactor<Pose2>; using PlanarMotionModel = BetweenFactor<Pose2>;
/*************** Run Round 2 ***************/
// Add odometry factor with discrete modes. // Add odometry factor with discrete modes.
Pose2 odometry(1.0, 0.0, 0.0); Pose2 odometry(1.0, 0.0, 0.0);
KeyVector contKeys = {W(0), W(1)}; KeyVector contKeys = {W(0), W(1)};