Merge pull request #46 from varunagrawal/feature/eliminate-hybrid

Break up EliminateHybrid into smaller functions
release/4.3a0
Varun Agrawal 2022-06-01 23:05:52 -04:00 committed by GitHub
commit 1a8fa235c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 154 additions and 139 deletions

View File

@ -55,14 +55,6 @@ namespace gtsam {
template class EliminateableFactorGraph<HybridFactorGraph>; template class EliminateableFactorGraph<HybridFactorGraph>;
static std::string BLACK_BOLD = "\033[1;30m";
static std::string RED_BOLD = "\033[1;31m";
static std::string GREEN = "\033[0;32m";
static std::string GREEN_BOLD = "\033[1;32m";
static std::string RESET = "\033[0m";
constexpr bool DEBUG = false;
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian( static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
@ -85,98 +77,9 @@ static GaussianMixtureFactor::Sum &addGaussian(
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> // std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { continuousElimination(const HybridFactorGraph &factors,
// NOTE(fan): Because we are in the Conditional Gaussian regime there are only const Ordering &frontalKeys) {
// a few cases:
// continuous variable, we make a GM if there are hybrid factors;
// continuous variable, we make a GF if there are no hybrid factors;
// discrete variable, no continuous factor is allowed (escapes CG regime), so
// we panic, if discrete only we do the discrete elimination.
// However it is not that simple. During elimination it is possible that the
// multifrontal needs to eliminate an ordering that contains both Gaussian and
// hybrid variables, for example x1, c1. In this scenario, we will have a
// density P(x1, c1) that is a CLG P(x1|c1)P(c1) (see Murphy02)
// The issue here is that, how can we know which variable is discrete if we
// unify Values? Obviously we can tell using the factors, but is that fast?
// In the case of multifrontal, we will need to use a constrained ordering
// so that the discrete parts will be guaranteed to be eliminated last!
// Because of all these reasons, we need to think very carefully about how to
// implement the hybrid factors so that we do not get poor performance.
//
// The first thing is how to represent the GaussianMixtureConditional. A very
// possible scenario is that the incoming factors will have different levels
// of discrete keys. For example, imagine we are going to eliminate the
// fragment:
// $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid. Now we will
// need to know how to retrieve the corresponding continuous densities for the
// assi- -gnment (c1,c2,c3) (OR (c2,c3,c1)! note there is NO defined order!).
// And we also need to consider when there is pruning. Two mixture factors
// could have different pruning patterns-one could have (c1=0,c2=1) pruned,
// and another could have (c2=0,c3=1) pruned, and this creates a big problem
// in how to identify the intersection of non-pruned branches.
// One possible approach is first building the collection of all discrete
// keys. After that we enumerate the space of all key combinations *lazily* so
// that the exploration branch terminates whenever an assignment yields NULL
// in any of the hybrid factors.
// When the number of assignments is large we may encounter stack overflows.
// However this is also the case with iSAM2, so no pressure :)
// PREPROCESS: Identify the nature of the current elimination
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
std::set<DiscreteKey> discreteSeparatorSet;
std::set<DiscreteKey> discreteFrontals;
KeySet separatorKeys;
KeySet allContinuousKeys;
KeySet continuousFrontals;
KeySet continuousSeparator;
// This initializes separatorKeys and mapFromKeyToDiscreteKey
for (auto &&factor : factors) {
separatorKeys.insert(factor->begin(), factor->end());
if (!factor->isContinuous()) {
for (auto &k : factor->discreteKeys()) {
mapFromKeyToDiscreteKey[k.first] = k;
}
}
}
// remove frontals from separator
for (auto &k : frontalKeys) {
separatorKeys.erase(k);
}
// Fill in discrete frontals and continuous frontals for the end result
for (auto &k : frontalKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousFrontals.insert(k);
allContinuousKeys.insert(k);
}
}
// Fill in discrete frontals and continuous frontals for the end result
for (auto &k : separatorKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousSeparator.insert(k);
allContinuousKeys.insert(k);
}
}
// NOTE: We should really defer the product here because of pruning
// Case 1: we are only dealing with continuous
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
GaussianFactorGraph gfg; GaussianFactorGraph gfg;
for (auto &fp : factors) { for (auto &fp : factors) {
auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp); auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp);
@ -198,8 +101,10 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
boost::make_shared<HybridGaussianFactor>(result.second)); boost::make_shared<HybridGaussianFactor>(result.second));
} }
// Case 2: we are only dealing with discrete /* ************************************************************************ */
if (allContinuousKeys.empty()) { std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
discreteElimination(const HybridFactorGraph &factors,
const Ordering &frontalKeys) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto &fp : factors) { for (auto &fp : factors) {
auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp); auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp);
@ -222,9 +127,13 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
boost::make_shared<HybridDiscreteFactor>(result.second)); boost::make_shared<HybridDiscreteFactor>(result.second));
} }
// Case 3: We are now in the hybrid land! /* ************************************************************************ */
// NOTE: since we use the special JunctionTree, only possiblity is cont. std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
// conditioned on disc. hybridElimination(const HybridFactorGraph &factors, const Ordering &frontalKeys,
const KeySet &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree,
// only possiblity is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end()); discreteSeparatorSet.end());
@ -232,7 +141,6 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
gttic(sum); gttic(sum);
GaussianMixtureFactor::Sum sum; GaussianMixtureFactor::Sum sum;
std::vector<GaussianFactor::shared_ptr> deferredFactors; std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors) { for (auto &f : factors) {
@ -296,12 +204,11 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
auto pair = unzip(eliminationResults); auto pair = unzip(eliminationResults);
const GaussianMixtureConditional::Conditionals &conditionals = pair.first;
const GaussianMixtureFactor::Factors &separatorFactors = pair.second; const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
// Create the GaussianMixtureConditional from the conditionals // Create the GaussianMixtureConditional from the conditionals
auto conditional = boost::make_shared<GaussianMixtureConditional>( auto conditional = boost::make_shared<GaussianMixtureConditional>(
frontalKeys, keysOfSeparator, discreteSeparator, conditionals); frontalKeys, keysOfSeparator, discreteSeparator, pair.first);
// If there are no more continuous parents, then we should create here a // If there are no more continuous parents, then we should create here a
// DiscreteFactor, with the error for each discrete choice. // DiscreteFactor, with the error for each discrete choice.
@ -326,6 +233,114 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
return {boost::make_shared<HybridConditional>(conditional), factor}; return {boost::make_shared<HybridConditional>(conditional), factor};
} }
} }
/* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
// NOTE: Because we are in the Conditional Gaussian regime there are only
// a few cases:
// 1. continuous variable, make a Gaussian Mixture if there are hybrid
// factors;
// 2. continuous variable, we make a Gaussian Factor if there are no hybrid
// factors;
// 3. discrete variable, no continuous factor is allowed
// (escapes Conditional Gaussian regime), if discrete only we do the discrete
// elimination.
// However it is not that simple. During elimination it is possible that the
// multifrontal needs to eliminate an ordering that contains both Gaussian and
// hybrid variables, for example x1, c1.
// In this scenario, we will have a density P(x1, c1) that is a Conditional
// Linear Gaussian P(x1|c1)P(c1) (see Murphy02).
// The issue here is that, how can we know which variable is discrete if we
// unify Values? Obviously we can tell using the factors, but is that fast?
// In the case of multifrontal, we will need to use a constrained ordering
// so that the discrete parts will be guaranteed to be eliminated last!
// Because of all these reasons, we carefully consider how to
// implement the hybrid factors so that we do not get poor performance.
// The first thing is how to represent the GaussianMixtureConditional.
// A very possible scenario is that the incoming factors will have different
// levels of discrete keys. For example, imagine we are going to eliminate the
// fragment: $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid.
// Now we will need to know how to retrieve the corresponding continuous
// densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO
// defined order!). We also need to consider when there is pruning. Two
// mixture factors could have different pruning patterns - one could have
// (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this
// creates a big problem in how to identify the intersection of non-pruned
// branches.
// Our approach is first building the collection of all discrete keys. After
// that we enumerate the space of all key combinations *lazily* so that the
// exploration branch terminates whenever an assignment yields NULL in any of
// the hybrid factors.
// When the number of assignments is large we may encounter stack overflows.
// However this is also the case with iSAM2, so no pressure :)
// PREPROCESS: Identify the nature of the current elimination
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
std::set<DiscreteKey> discreteSeparatorSet;
std::set<DiscreteKey> discreteFrontals;
KeySet separatorKeys;
KeySet allContinuousKeys;
KeySet continuousFrontals;
KeySet continuousSeparator;
// This initializes separatorKeys and mapFromKeyToDiscreteKey
for (auto &&factor : factors) {
separatorKeys.insert(factor->begin(), factor->end());
if (!factor->isContinuous()) {
for (auto &k : factor->discreteKeys()) {
mapFromKeyToDiscreteKey[k.first] = k;
}
}
}
// remove frontals from separator
for (auto &k : frontalKeys) {
separatorKeys.erase(k);
}
// Fill in discrete frontals and continuous frontals for the end result
for (auto &k : frontalKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousFrontals.insert(k);
allContinuousKeys.insert(k);
}
}
// Fill in discrete frontals and continuous frontals for the end result
for (auto &k : separatorKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
} else {
continuousSeparator.insert(k);
allContinuousKeys.insert(k);
}
}
// NOTE: We should really defer the product here because of pruning
// Case 1: we are only dealing with continuous
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
return continuousElimination(factors, frontalKeys);
}
// Case 2: we are only dealing with discrete
if (allContinuousKeys.empty()) {
return discreteElimination(factors, frontalKeys);
}
// Case 3: We are now in the hybrid land!
return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparatorSet);
}
/* ************************************************************************ */ /* ************************************************************************ */
void HybridFactorGraph::add(JacobianFactor &&factor) { void HybridFactorGraph::add(JacobianFactor &&factor) {