major improvement: continuousSeparator no longer needed

release/4.3a0
Frank Dellaert 2024-09-27 07:54:47 -07:00
parent 7d51e1cdb4
commit bc25fcea4d
1 changed files with 10 additions and 16 deletions

View File

@ -369,7 +369,6 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
hybridElimination(const HybridGaussianFactorGraph &factors, hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys, const Ordering &frontalKeys,
const KeyVector &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) { const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree, // NOTE: since we use the special JunctionTree,
// only possibility is continuous conditioned on discrete. // only possibility is continuous conditioned on discrete.
@ -386,13 +385,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
factorGraphTree = removeEmpty(factorGraphTree); factorGraphTree = removeEmpty(factorGraphTree);
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
bool someContinuousLeft = false;
auto eliminate = [&](const GaussianFactorGraph &graph) -> Result { auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
if (graph.empty()) { if (graph.empty()) {
return {nullptr, nullptr}; return {nullptr, nullptr};
} }
// Expensive elimination of product factor.
auto result = EliminatePreferCholesky(graph, frontalKeys); auto result = EliminatePreferCholesky(graph, frontalKeys);
// Record whether there any continuous variables left
someContinuousLeft |= !result.second->empty();
return result; return result;
}; };
@ -403,9 +407,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// error for each discrete choice. Otherwise, create a HybridGaussianFactor // error for each discrete choice. Otherwise, create a HybridGaussianFactor
// on the separator, taking care to correct for conditional constants. // on the separator, taking care to correct for conditional constants.
auto newFactor = auto newFactor =
continuousSeparator.empty() someContinuousLeft
? createDiscreteFactor(eliminationResults, discreteSeparator) ? createHybridGaussianFactor(eliminationResults, discreteSeparator)
: createHybridGaussianFactor(eliminationResults, discreteSeparator); : createDiscreteFactor(eliminationResults, discreteSeparator);
// Create the HybridGaussianConditional from the conditionals // Create the HybridGaussianConditional from the conditionals
HybridGaussianConditional::Conditionals conditionals( HybridGaussianConditional::Conditionals conditionals(
@ -514,22 +518,12 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// Case 3: We are now in the hybrid land! // Case 3: We are now in the hybrid land!
KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end()); KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
// Find all the keys in the set of continuous keys // Find all discrete keys.
// which are not in the frontal keys. This is our continuous separator.
KeyVector continuousSeparator;
auto continuousKeySet = factors.continuousKeySet();
std::set_difference(
continuousKeySet.begin(), continuousKeySet.end(),
frontalKeysSet.begin(), frontalKeysSet.end(),
std::inserter(continuousSeparator, continuousSeparator.begin()));
// Similarly for the discrete separator.
// Since we eliminate all continuous variables first, // Since we eliminate all continuous variables first,
// the discrete separator will be *all* the discrete keys. // the discrete separator will be *all* the discrete keys.
std::set<DiscreteKey> discreteSeparator = factors.discreteKeys(); std::set<DiscreteKey> discreteSeparator = factors.discreteKeys();
return hybridElimination(factors, frontalKeys, continuousSeparator, return hybridElimination(factors, frontalKeys, discreteSeparator);
discreteSeparator);
} }
} }