diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a2d68853e..de1e1f867 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -62,40 +62,55 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} /** - * @brief Compute the correct ordering of the leaves in the decision tree. + * @brief Compute the indexing of the leaves in the decision tree based on the + * assignment and add the (index, leaf) pair to a SparseVector. + * + * We visit each leaf in the tree, and using the cardinalities of the keys, + * compute the correct index to add the leaf to a SparseVector which + * is then used to create the TableFactor. * * @param dt The DecisionTree * @return Eigen::SparseVector */ -static Eigen::SparseVector ComputeLeafOrdering( +static Eigen::SparseVector ComputeSparseTable( const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) { // SparseVector needs to know the maximum possible index, // so we compute the product of cardinalities. - size_t prod_cardinality = 1; + size_t cardinalityProduct = 1; for (auto&& [_, c] : dt.cardinalities()) { - prod_cardinality *= c; + cardinalityProduct *= c; } - Eigen::SparseVector sparse_table(prod_cardinality); + Eigen::SparseVector sparseTable(cardinalityProduct); size_t nrValues = 0; dt.visit([&nrValues](double x) { if (x > 0) nrValues += 1; }); - sparse_table.reserve(nrValues); + sparseTable.reserve(nrValues); std::set allKeys(dt.keys().begin(), dt.keys().end()); + /** + * @brief Functor which is called by the DecisionTree for each leaf. + * For each leaf value, we use the corresponding assignment to compute a + * corresponding index into a SparseVector. We then populate sparseTable with + * the value at the computed index. + * + * Takes advantage of the sparsity of the DecisionTree to be efficient. When + * merged branches are encountered, we enumerate over the missing keys. + * + */ auto op = [&](const Assignment& assignment, double p) { if (p > 0) { // Get all the keys involved in this assignment - std::set assignment_keys; + std::set assignmentKeys; for (auto&& [k, _] : assignment) { - assignment_keys.insert(k); + assignmentKeys.insert(k); } // Find the keys missing in the assignment std::vector diff; std::set_difference(allKeys.begin(), allKeys.end(), - assignment_keys.begin(), assignment_keys.end(), + assignmentKeys.begin(), assignmentKeys.end(), std::back_inserter(diff)); // Generate all assignments using the missing keys @@ -103,41 +118,43 @@ static Eigen::SparseVector ComputeLeafOrdering( for (auto&& key : diff) { extras.push_back({key, dt.cardinality(key)}); } - auto&& extra_assignments = DiscreteValues::CartesianProduct(extras); + auto&& extraAssignments = DiscreteValues::CartesianProduct(extras); - for (auto&& extra : extra_assignments) { + for (auto&& extra : extraAssignments) { // Create new assignment using the extra assignment - DiscreteValues updated_assignment(assignment); - updated_assignment.insert(extra); + DiscreteValues updatedAssignment(assignment); + updatedAssignment.insert(extra); // Generate index and add to the sparse vector. Eigen::Index idx = 0; - size_t prev_cardinality = 1; + size_t previousCardinality = 1; // We go in reverse since a DecisionTree has the highest label first - for (auto&& it = updated_assignment.rbegin(); - it != updated_assignment.rend(); it++) { - idx += prev_cardinality * it->second; - prev_cardinality *= dt.cardinality(it->first); + for (auto&& it = updatedAssignment.rbegin(); + it != updatedAssignment.rend(); it++) { + idx += previousCardinality * it->second; + previousCardinality *= dt.cardinality(it->first); } - sparse_table.coeffRef(idx) = p; + sparseTable.coeffRef(idx) = p; } } }; + // Visit each leaf in `dt` to get the Assignment and leaf value + // to populate the sparseTable. dt.visitWith(op); - return sparse_table; + return sparseTable; } /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteKeys& dkeys, const DecisionTreeFactor& dtf) - : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} + : TableFactor(dkeys, ComputeSparseTable(dkeys, dtf)) {} /* ************************************************************************ */ TableFactor::TableFactor(const DecisionTreeFactor& dtf) : TableFactor(dtf.discreteKeys(), - ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {} + ComputeSparseTable(dtf.discreteKeys(), dtf)) {} /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c)