address review comments

release/4.3a0
Varun Agrawal 2024-12-13 13:58:19 -05:00
parent 7d389a5300
commit 9830981351
1 changed files with 39 additions and 22 deletions

View File

@ -62,40 +62,55 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
: TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {}
/** /**
* @brief Compute the correct ordering of the leaves in the decision tree. * @brief Compute the indexing of the leaves in the decision tree based on the
* assignment and add the (index, leaf) pair to a SparseVector.
*
* We visit each leaf in the tree, and using the cardinalities of the keys,
* compute the correct index to add the leaf to a SparseVector which
* is then used to create the TableFactor.
* *
* @param dt The DecisionTree * @param dt The DecisionTree
* @return Eigen::SparseVector<double> * @return Eigen::SparseVector<double>
*/ */
static Eigen::SparseVector<double> ComputeLeafOrdering( static Eigen::SparseVector<double> ComputeSparseTable(
const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) { const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) {
// SparseVector needs to know the maximum possible index, // SparseVector needs to know the maximum possible index,
// so we compute the product of cardinalities. // so we compute the product of cardinalities.
size_t prod_cardinality = 1; size_t cardinalityProduct = 1;
for (auto&& [_, c] : dt.cardinalities()) { for (auto&& [_, c] : dt.cardinalities()) {
prod_cardinality *= c; cardinalityProduct *= c;
} }
Eigen::SparseVector<double> sparse_table(prod_cardinality); Eigen::SparseVector<double> sparseTable(cardinalityProduct);
size_t nrValues = 0; size_t nrValues = 0;
dt.visit([&nrValues](double x) { dt.visit([&nrValues](double x) {
if (x > 0) nrValues += 1; if (x > 0) nrValues += 1;
}); });
sparse_table.reserve(nrValues); sparseTable.reserve(nrValues);
std::set<Key> allKeys(dt.keys().begin(), dt.keys().end()); std::set<Key> allKeys(dt.keys().begin(), dt.keys().end());
/**
* @brief Functor which is called by the DecisionTree for each leaf.
* For each leaf value, we use the corresponding assignment to compute a
* corresponding index into a SparseVector. We then populate sparseTable with
* the value at the computed index.
*
* Takes advantage of the sparsity of the DecisionTree to be efficient. When
* merged branches are encountered, we enumerate over the missing keys.
*
*/
auto op = [&](const Assignment<Key>& assignment, double p) { auto op = [&](const Assignment<Key>& assignment, double p) {
if (p > 0) { if (p > 0) {
// Get all the keys involved in this assignment // Get all the keys involved in this assignment
std::set<Key> assignment_keys; std::set<Key> assignmentKeys;
for (auto&& [k, _] : assignment) { for (auto&& [k, _] : assignment) {
assignment_keys.insert(k); assignmentKeys.insert(k);
} }
// Find the keys missing in the assignment // Find the keys missing in the assignment
std::vector<Key> diff; std::vector<Key> diff;
std::set_difference(allKeys.begin(), allKeys.end(), std::set_difference(allKeys.begin(), allKeys.end(),
assignment_keys.begin(), assignment_keys.end(), assignmentKeys.begin(), assignmentKeys.end(),
std::back_inserter(diff)); std::back_inserter(diff));
// Generate all assignments using the missing keys // Generate all assignments using the missing keys
@ -103,41 +118,43 @@ static Eigen::SparseVector<double> ComputeLeafOrdering(
for (auto&& key : diff) { for (auto&& key : diff) {
extras.push_back({key, dt.cardinality(key)}); extras.push_back({key, dt.cardinality(key)});
} }
auto&& extra_assignments = DiscreteValues::CartesianProduct(extras); auto&& extraAssignments = DiscreteValues::CartesianProduct(extras);
for (auto&& extra : extra_assignments) { for (auto&& extra : extraAssignments) {
// Create new assignment using the extra assignment // Create new assignment using the extra assignment
DiscreteValues updated_assignment(assignment); DiscreteValues updatedAssignment(assignment);
updated_assignment.insert(extra); updatedAssignment.insert(extra);
// Generate index and add to the sparse vector. // Generate index and add to the sparse vector.
Eigen::Index idx = 0; Eigen::Index idx = 0;
size_t prev_cardinality = 1; size_t previousCardinality = 1;
// We go in reverse since a DecisionTree has the highest label first // We go in reverse since a DecisionTree has the highest label first
for (auto&& it = updated_assignment.rbegin(); for (auto&& it = updatedAssignment.rbegin();
it != updated_assignment.rend(); it++) { it != updatedAssignment.rend(); it++) {
idx += prev_cardinality * it->second; idx += previousCardinality * it->second;
prev_cardinality *= dt.cardinality(it->first); previousCardinality *= dt.cardinality(it->first);
} }
sparse_table.coeffRef(idx) = p; sparseTable.coeffRef(idx) = p;
} }
} }
}; };
// Visit each leaf in `dt` to get the Assignment and leaf value
// to populate the sparseTable.
dt.visitWith(op); dt.visitWith(op);
return sparse_table; return sparseTable;
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys, TableFactor::TableFactor(const DiscreteKeys& dkeys,
const DecisionTreeFactor& dtf) const DecisionTreeFactor& dtf)
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} : TableFactor(dkeys, ComputeSparseTable(dkeys, dtf)) {}
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DecisionTreeFactor& dtf) TableFactor::TableFactor(const DecisionTreeFactor& dtf)
: TableFactor(dtf.discreteKeys(), : TableFactor(dtf.discreteKeys(),
ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {} ComputeSparseTable(dtf.discreteKeys(), dtf)) {}
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c) TableFactor::TableFactor(const DiscreteConditional& c)