diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index ad9a943e3..aa7f1d391 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -47,15 +47,6 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, const DecisionTreeFactor& f) : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} -/* ************************************************************************** */ -DiscreteConditional::DiscreteConditional(size_t nrFrontals, - const DecisionTreeFactor& f, - const Ordering& orderedKeys) - : BaseFactor(f), BaseConditional(nrFrontals) { - keys_.clear(); - keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); -} - /* ************************************************************************** */ DiscreteConditional::DiscreteConditional(size_t nrFrontals, const DiscreteKeys& keys, diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 89b8d0d56..98edcb8c9 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -56,17 +56,6 @@ class GTSAM_EXPORT DiscreteConditional /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); - /** - * @brief Construct from DecisionTreeFactor, - * taking the first `nrFrontals` from `orderedKeys`. - * - * @param nrFrontals The number of frontal variables. - * @param f The DecisionTreeFactor to construct from. - * @param orderedKeys Ordered list of keys involved in the conditional. - */ - DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f, - const Ordering& orderedKeys); - /** * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first * `nFrontals` keys as frontals, in the order given. diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a833e1c5e..bf9662e34 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -252,15 +252,6 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); - // If no keys, then return empty DecisionTreeFactor - if (dkeys.size() == 0) { - AlgebraicDecisionTree tree; - if (sparse_table_.size() != 0) { - tree = AlgebraicDecisionTree(sparse_table_.coeff(0)); - } - return DecisionTreeFactor(dkeys, tree); - } - std::vector table; for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 502a52742..0213cd64b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -256,7 +256,7 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( } /* ************************************************************************ */ -TableFactor TableProductAndNormalize(const DiscreteFactorGraph &factors) { +TableFactor TableProduct(const DiscreteFactorGraph &factors) { // PRODUCT: multiply all factors #if GTSAM_HYBRID_TIMING gttic_(DiscreteProduct); @@ -279,14 +279,13 @@ TableFactor TableProductAndNormalize(const DiscreteFactorGraph &factors) { gttoc_(DiscreteProduct); #endif - // Max over all the potentials by pretending all keys are frontal: - auto normalizer = product.max(product.size()); - #if GTSAM_HYBRID_TIMING gttic_(DiscreteNormalize); #endif + // Max over all the potentials by pretending all keys are frontal: + auto denominator = product.max(product.size()); // Normalize the product factor to prevent underflow. - product = product / (*normalizer); + product = product / (*denominator); #if GTSAM_HYBRID_TIMING gttoc_(DiscreteNormalize); #endif @@ -343,41 +342,40 @@ discreteElimination(const HybridGaussianFactorGraph &factors, #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif - /**** NOTE: This does sum-product. ****/ - // Get product factor - TableFactor product = TableProductAndNormalize(dfg); + // Check if separator is empty + Ordering allKeys(dfg.keyVector()); + Ordering separator; + std::set_difference(allKeys.begin(), allKeys.end(), frontalKeys.begin(), + frontalKeys.end(), + std::inserter(separator, separator.begin())); + + // If the separator is empty, we have a clique of all the discrete variables + // so we can use the TableFactor for efficiency. + if (separator.size() == 0) { + // Get product factor + TableFactor product = TableProduct(dfg); #if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteSum); + gttic_(EliminateDiscreteFormDiscreteConditional); #endif - // All the discrete variables should form a single clique, - // so we can sum out on all the variables as frontals. - // This should give an empty separator. - TableFactor::shared_ptr sum = product.sum(frontalKeys); + auto conditional = std::make_shared( + frontalKeys.size(), product.toDecisionTreeFactor()); #if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteSum); + gttoc_(EliminateDiscreteFormDiscreteConditional); #endif - // Ordering keys for the conditional so that frontalKeys are really in front - Ordering orderedKeys; - orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); - orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); + TableFactor::shared_ptr sum = product.sum(frontalKeys); -#if GTSAM_HYBRID_TIMING - gttic_(EliminateDiscreteFormDiscreteConditional); -#endif - // Finally, get the conditional - auto conditional = - std::make_shared(product, *sum, orderedKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscreteFormDiscreteConditional); -#endif + return {std::make_shared(conditional), sum}; + } else { + // Perform sum-product. + auto result = EliminateDiscrete(dfg, frontalKeys); + return {std::make_shared(result.first), result.second}; + } #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscrete); #endif - - return {std::make_shared(conditional), sum}; } /* ************************************************************************ */