From a8e24efdeca680f391098979ad2402dd3918153f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 13 Dec 2024 09:34:01 -0500 Subject: [PATCH] update ComputeLeafOrdering to give a correct vector of values --- gtsam/discrete/TableFactor.cpp | 72 +++++++++++++++++++++++++++------- 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 8e185eb3b..a2d68853e 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -64,27 +64,69 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, /** * @brief Compute the correct ordering of the leaves in the decision tree. * - * This is done by first taking all the values which have modulo 0 value with - * the cardinality of the innermost key `n`, and we go up to modulo n. - * * @param dt The DecisionTree - * @return std::vector + * @return Eigen::SparseVector */ -std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys, - const DecisionTreeFactor& dt) { - std::vector probs = dt.probabilities(); - std::vector ordered; +static Eigen::SparseVector ComputeLeafOrdering( + const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) { + // SparseVector needs to know the maximum possible index, + // so we compute the product of cardinalities. + size_t prod_cardinality = 1; + for (auto&& [_, c] : dt.cardinalities()) { + prod_cardinality *= c; + } + Eigen::SparseVector sparse_table(prod_cardinality); + size_t nrValues = 0; + dt.visit([&nrValues](double x) { + if (x > 0) nrValues += 1; + }); + sparse_table.reserve(nrValues); - size_t n = dkeys[0].second; + std::set allKeys(dt.keys().begin(), dt.keys().end()); - for (size_t k = 0; k < n; ++k) { - for (size_t idx = 0; idx < probs.size(); ++idx) { - if (idx % n == k) { - ordered.push_back(probs[idx]); + auto op = [&](const Assignment& assignment, double p) { + if (p > 0) { + // Get all the keys involved in this assignment + std::set assignment_keys; + for (auto&& [k, _] : assignment) { + assignment_keys.insert(k); + } + + // Find the keys missing in the assignment + std::vector diff; + std::set_difference(allKeys.begin(), allKeys.end(), + assignment_keys.begin(), assignment_keys.end(), + std::back_inserter(diff)); + + // Generate all assignments using the missing keys + DiscreteKeys extras; + for (auto&& key : diff) { + extras.push_back({key, dt.cardinality(key)}); + } + auto&& extra_assignments = DiscreteValues::CartesianProduct(extras); + + for (auto&& extra : extra_assignments) { + // Create new assignment using the extra assignment + DiscreteValues updated_assignment(assignment); + updated_assignment.insert(extra); + + // Generate index and add to the sparse vector. + Eigen::Index idx = 0; + size_t prev_cardinality = 1; + // We go in reverse since a DecisionTree has the highest label first + for (auto&& it = updated_assignment.rbegin(); + it != updated_assignment.rend(); it++) { + idx += prev_cardinality * it->second; + prev_cardinality *= dt.cardinality(it->first); + } + sparse_table.coeffRef(idx) = p; } } - } - return ordered; + }; + + dt.visitWith(op); + + return sparse_table; } /* ************************************************************************ */