diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index d0bf21047..227bb4da3 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -112,13 +112,12 @@ namespace gtsam { // } /** - * @brief Multiply all the `factors` and normalize the - * product to prevent underflow. + * @brief Multiply all the `factors`. * * @param factors The factors to multiply as a DiscreteFactorGraph. * @return DecisionTreeFactor */ - static DecisionTreeFactor ProductAndNormalize( + static DecisionTreeFactor DiscreteProduct( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors gttic(product); @@ -126,10 +125,10 @@ namespace gtsam { gttoc(product); // Max over all the potentials by pretending all keys are frontal: - auto normalization = product.max(product.size()); + auto denominator = product.max(product.size()); // Normalize the product factor to prevent underflow. - product = product / (*normalization); + product = product / (*denominator); return product; } @@ -139,7 +138,7 @@ namespace gtsam { std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = ProductAndNormalize(factors); + DecisionTreeFactor product = DiscreteProduct(factors); // max out frontals, this is the factor on the separator gttic(max); @@ -207,8 +206,7 @@ namespace gtsam { return dag.argmax(); } - DiscreteValues DiscreteFactorGraph::optimize( - const Ordering& ordering) const { + DiscreteValues DiscreteFactorGraph::optimize(const Ordering& ordering) const { gttic(DiscreteFactorGraph_optimize); DiscreteLookupDAG dag = maxProduct(ordering); return dag.argmax(); @@ -218,7 +216,7 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DecisionTreeFactor product = ProductAndNormalize(factors); + DecisionTreeFactor product = DiscreteProduct(factors); // sum out frontals, this is the factor on the separator gttic(sum); diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 22548da07..bf9662e34 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -252,41 +252,12 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); - // Record key assignment and value pairs in pair_table. - // The assignments are stored in descending order of keys so that the order of - // the values matches what is expected by a DecisionTree. - // This is why we reverse the keys and then - // query for the key value/assignment. - DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend()); - std::vector> pair_table; + std::vector table; for (auto i = 0; i < sparse_table_.size(); i++) { - std::stringstream ss; - for (auto&& [key, _] : rdkeys) { - ss << keyValueForIndex(key, i); - } - // k will be in reverse key order already - uint64_t k; - ss >> k; - pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i))); + table.push_back(sparse_table_.coeff(i)); } - // Sort the pair_table (of assignment-value pairs) based on assignment so we - // get values in reverse key order. - std::sort( - pair_table.begin(), pair_table.end(), - [](const std::pair& a, - const std::pair& b) { return a.first < b.first; }); - - // Create the table vector by extracting the values from pair_table. - // The pair_table has already been sorted in the desired order, - // so the values will be in descending key order. - std::vector table; - std::for_each(pair_table.begin(), pair_table.end(), - [&table](const std::pair& pair) { - table.push_back(pair.second); - }); - - AlgebraicDecisionTree tree(rdkeys, table); + AlgebraicDecisionTree tree(dkeys, table); DecisionTreeFactor f(dkeys, tree); return f; } diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 0d71c12ba..cbcf5234e 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) { *std::dynamic_pointer_cast(newFactorPtr); // Normalize newFactor by max for comparison with expected - auto normalization = newFactor.max(newFactor.size()); + auto normalizer = newFactor.max(newFactor.size()); - newFactor = newFactor / *normalization; + newFactor = newFactor / *normalizer; // Check Conditional CHECK(conditional); @@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) { CHECK(&newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); // Normalize by max. - normalization = expectedFactor.max(expectedFactor.size()); - // Ensure normalization is correct. - expectedFactor = expectedFactor / *normalization; + normalizer = expectedFactor.max(expectedFactor.size()); + // Ensure normalizer is correct. + expectedFactor = expectedFactor / *normalizer; EXPECT(assert_equal(expectedFactor, newFactor)); // Test using elimination tree diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8832cbb34..b9051554a 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -282,7 +282,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } else if (auto hc = dynamic_pointer_cast(f)) { auto dc = hc->asDiscrete(); if (!dc) throwRuntimeError("discreteElimination", dc); - dfg.push_back(hc->asDiscrete()); + dfg.push_back(dc); } else { throwRuntimeError("discreteElimination", f); }