diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 2b23ed4db..d2ea3d5ef 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -348,64 +348,68 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // When the number of assignments is large we may encounter stack overflows. // However this is also the case with iSAM2, so no pressure :) - // PREPROCESS: Identify the nature of the current elimination - - // TODO(dellaert): just check the factors: + // Check the factors: // 1. if all factors are discrete, then we can do discrete elimination: // 2. if all factors are continuous, then we can do continuous elimination: // 3. if not, we do hybrid elimination: - // First, identify the separator keys, i.e. all keys that are not frontal. - KeySet separatorKeys; + bool only_discrete = true, only_continuous = true; for (auto &&factor : factors) { - separatorKeys.insert(factor->begin(), factor->end()); - } - // remove frontals from separator - for (auto &k : frontalKeys) { - separatorKeys.erase(k); - } - - // Build a map from keys to DiscreteKeys - auto mapFromKeyToDiscreteKey = factors.discreteKeyMap(); - - // Fill in discrete frontals and continuous frontals. - std::set discreteFrontals; - KeySet continuousFrontals; - for (auto &k : frontalKeys) { - if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { - discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k)); - } else { - continuousFrontals.insert(k); + if (auto hybrid_factor = std::dynamic_pointer_cast(factor)) { + if (hybrid_factor->isDiscrete()) { + only_continuous = false; + } else if (hybrid_factor->isContinuous()) { + only_discrete = false; + } else if (hybrid_factor->isHybrid()) { + only_continuous = false; + only_discrete = false; + } + } else if (auto cont_factor = + std::dynamic_pointer_cast(factor)) { + only_discrete = false; + } else if (auto discrete_factor = + std::dynamic_pointer_cast(factor)) { + only_continuous = false; } } - // Fill in discrete discrete separator keys and continuous separator keys. - std::set discreteSeparatorSet; - KeyVector continuousSeparator; - for (auto &k : separatorKeys) { - if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { - discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); - } else { - continuousSeparator.push_back(k); - } - } - - // Check if we have any continuous keys: - const bool discrete_only = - continuousFrontals.empty() && continuousSeparator.empty(); - // NOTE: We should really defer the product here because of pruning - if (discrete_only) { + if (only_discrete) { // Case 1: we are only dealing with discrete return discreteElimination(factors, frontalKeys); - } else if (mapFromKeyToDiscreteKey.empty()) { + } else if (only_continuous) { // Case 2: we are only dealing with continuous return continuousElimination(factors, frontalKeys); } else { // Case 3: We are now in the hybrid land! + KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end()); + + // Find all the keys in the set of continuous keys + // which are not in the frontal keys. This is our continuous separator. + KeyVector continuousSeparator; + auto continuousKeySet = factors.continuousKeySet(); + std::set_difference( + continuousKeySet.begin(), continuousKeySet.end(), + frontalKeysSet.begin(), frontalKeysSet.end(), + std::inserter(continuousSeparator, continuousSeparator.begin())); + + // Similarly for the discrete separator. + KeySet discreteSeparatorSet; + std::set discreteSeparator; + auto discreteKeySet = factors.discreteKeySet(); + std::set_difference( + discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(), + frontalKeysSet.end(), + std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin())); + // Convert from set of keys to set of DiscreteKeys + auto discreteKeyMap = factors.discreteKeyMap(); + for (auto key : discreteSeparatorSet) { + discreteSeparator.insert(discreteKeyMap.at(key)); + } + return hybridElimination(factors, frontalKeys, continuousSeparator, - discreteSeparatorSet); + discreteSeparator); } }