diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 8e185eb3b..de1e1f867 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -62,40 +62,99 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} /** - * @brief Compute the correct ordering of the leaves in the decision tree. + * @brief Compute the indexing of the leaves in the decision tree based on the + * assignment and add the (index, leaf) pair to a SparseVector. * - * This is done by first taking all the values which have modulo 0 value with - * the cardinality of the innermost key `n`, and we go up to modulo n. + * We visit each leaf in the tree, and using the cardinalities of the keys, + * compute the correct index to add the leaf to a SparseVector which + * is then used to create the TableFactor. * * @param dt The DecisionTree - * @return std::vector + * @return Eigen::SparseVector */ -std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys, - const DecisionTreeFactor& dt) { - std::vector probs = dt.probabilities(); - std::vector ordered; +static Eigen::SparseVector ComputeSparseTable( + const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) { + // SparseVector needs to know the maximum possible index, + // so we compute the product of cardinalities. + size_t cardinalityProduct = 1; + for (auto&& [_, c] : dt.cardinalities()) { + cardinalityProduct *= c; + } + Eigen::SparseVector sparseTable(cardinalityProduct); + size_t nrValues = 0; + dt.visit([&nrValues](double x) { + if (x > 0) nrValues += 1; + }); + sparseTable.reserve(nrValues); - size_t n = dkeys[0].second; + std::set allKeys(dt.keys().begin(), dt.keys().end()); - for (size_t k = 0; k < n; ++k) { - for (size_t idx = 0; idx < probs.size(); ++idx) { - if (idx % n == k) { - ordered.push_back(probs[idx]); + /** + * @brief Functor which is called by the DecisionTree for each leaf. + * For each leaf value, we use the corresponding assignment to compute a + * corresponding index into a SparseVector. We then populate sparseTable with + * the value at the computed index. + * + * Takes advantage of the sparsity of the DecisionTree to be efficient. When + * merged branches are encountered, we enumerate over the missing keys. + * + */ + auto op = [&](const Assignment& assignment, double p) { + if (p > 0) { + // Get all the keys involved in this assignment + std::set assignmentKeys; + for (auto&& [k, _] : assignment) { + assignmentKeys.insert(k); + } + + // Find the keys missing in the assignment + std::vector diff; + std::set_difference(allKeys.begin(), allKeys.end(), + assignmentKeys.begin(), assignmentKeys.end(), + std::back_inserter(diff)); + + // Generate all assignments using the missing keys + DiscreteKeys extras; + for (auto&& key : diff) { + extras.push_back({key, dt.cardinality(key)}); + } + auto&& extraAssignments = DiscreteValues::CartesianProduct(extras); + + for (auto&& extra : extraAssignments) { + // Create new assignment using the extra assignment + DiscreteValues updatedAssignment(assignment); + updatedAssignment.insert(extra); + + // Generate index and add to the sparse vector. + Eigen::Index idx = 0; + size_t previousCardinality = 1; + // We go in reverse since a DecisionTree has the highest label first + for (auto&& it = updatedAssignment.rbegin(); + it != updatedAssignment.rend(); it++) { + idx += previousCardinality * it->second; + previousCardinality *= dt.cardinality(it->first); + } + sparseTable.coeffRef(idx) = p; } } - } - return ordered; + }; + + // Visit each leaf in `dt` to get the Assignment and leaf value + // to populate the sparseTable. + dt.visitWith(op); + + return sparseTable; } /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteKeys& dkeys, const DecisionTreeFactor& dtf) - : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} + : TableFactor(dkeys, ComputeSparseTable(dkeys, dtf)) {} /* ************************************************************************ */ TableFactor::TableFactor(const DecisionTreeFactor& dtf) : TableFactor(dtf.discreteKeys(), - ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {} + ComputeSparseTable(dtf.discreteKeys(), dtf)) {} /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 212067cb3..a455faaaa 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -147,6 +147,34 @@ TEST(TableFactor, constructors) { EXPECT(assert_inequal(f5_with_wrong_keys, f5, 1e-9)); } +/* ************************************************************************* */ +// Check conversion from DecisionTreeFactor. +TEST(TableFactor, Conversion) { + /* This is the DecisionTree we are using + Choice(m2) + 0 Choice(m1) + 0 0 Leaf 0 + 0 1 Choice(m0) + 0 1 0 Leaf 0 + 0 1 1 Leaf 0.14649446 // 3 + 1 Choice(m1) + 1 0 Choice(m0) + 1 0 0 Leaf 0 + 1 0 1 Leaf 0.14648756 // 5 + 1 1 Choice(m0) + 1 1 0 Leaf 0.14649446 // 6 + 1 1 1 Leaf 0.23918345 // 7 + */ + DiscreteKeys dkeys = {{0, 2}, {1, 2}, {2, 2}}; + DecisionTreeFactor dtf( + dkeys, std::vector{0, 0, 0, 0.14649446, 0, 0.14648756, 0.14649446, + 0.23918345}); + + TableFactor tf(dtf.discreteKeys(), dtf); + + EXPECT(assert_equal(dtf, tf.toDecisionTreeFactor())); +} + /* ************************************************************************* */ // Check multiplication between two TableFactors. TEST(TableFactor, multiplication) {