address review comments
parent
7d389a5300
commit
9830981351
|
@ -62,40 +62,55 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
|||
: TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {}
|
||||
|
||||
/**
|
||||
* @brief Compute the correct ordering of the leaves in the decision tree.
|
||||
* @brief Compute the indexing of the leaves in the decision tree based on the
|
||||
* assignment and add the (index, leaf) pair to a SparseVector.
|
||||
*
|
||||
* We visit each leaf in the tree, and using the cardinalities of the keys,
|
||||
* compute the correct index to add the leaf to a SparseVector which
|
||||
* is then used to create the TableFactor.
|
||||
*
|
||||
* @param dt The DecisionTree
|
||||
* @return Eigen::SparseVector<double>
|
||||
*/
|
||||
static Eigen::SparseVector<double> ComputeLeafOrdering(
|
||||
static Eigen::SparseVector<double> ComputeSparseTable(
|
||||
const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) {
|
||||
// SparseVector needs to know the maximum possible index,
|
||||
// so we compute the product of cardinalities.
|
||||
size_t prod_cardinality = 1;
|
||||
size_t cardinalityProduct = 1;
|
||||
for (auto&& [_, c] : dt.cardinalities()) {
|
||||
prod_cardinality *= c;
|
||||
cardinalityProduct *= c;
|
||||
}
|
||||
Eigen::SparseVector<double> sparse_table(prod_cardinality);
|
||||
Eigen::SparseVector<double> sparseTable(cardinalityProduct);
|
||||
size_t nrValues = 0;
|
||||
dt.visit([&nrValues](double x) {
|
||||
if (x > 0) nrValues += 1;
|
||||
});
|
||||
sparse_table.reserve(nrValues);
|
||||
sparseTable.reserve(nrValues);
|
||||
|
||||
std::set<Key> allKeys(dt.keys().begin(), dt.keys().end());
|
||||
|
||||
/**
|
||||
* @brief Functor which is called by the DecisionTree for each leaf.
|
||||
* For each leaf value, we use the corresponding assignment to compute a
|
||||
* corresponding index into a SparseVector. We then populate sparseTable with
|
||||
* the value at the computed index.
|
||||
*
|
||||
* Takes advantage of the sparsity of the DecisionTree to be efficient. When
|
||||
* merged branches are encountered, we enumerate over the missing keys.
|
||||
*
|
||||
*/
|
||||
auto op = [&](const Assignment<Key>& assignment, double p) {
|
||||
if (p > 0) {
|
||||
// Get all the keys involved in this assignment
|
||||
std::set<Key> assignment_keys;
|
||||
std::set<Key> assignmentKeys;
|
||||
for (auto&& [k, _] : assignment) {
|
||||
assignment_keys.insert(k);
|
||||
assignmentKeys.insert(k);
|
||||
}
|
||||
|
||||
// Find the keys missing in the assignment
|
||||
std::vector<Key> diff;
|
||||
std::set_difference(allKeys.begin(), allKeys.end(),
|
||||
assignment_keys.begin(), assignment_keys.end(),
|
||||
assignmentKeys.begin(), assignmentKeys.end(),
|
||||
std::back_inserter(diff));
|
||||
|
||||
// Generate all assignments using the missing keys
|
||||
|
@ -103,41 +118,43 @@ static Eigen::SparseVector<double> ComputeLeafOrdering(
|
|||
for (auto&& key : diff) {
|
||||
extras.push_back({key, dt.cardinality(key)});
|
||||
}
|
||||
auto&& extra_assignments = DiscreteValues::CartesianProduct(extras);
|
||||
auto&& extraAssignments = DiscreteValues::CartesianProduct(extras);
|
||||
|
||||
for (auto&& extra : extra_assignments) {
|
||||
for (auto&& extra : extraAssignments) {
|
||||
// Create new assignment using the extra assignment
|
||||
DiscreteValues updated_assignment(assignment);
|
||||
updated_assignment.insert(extra);
|
||||
DiscreteValues updatedAssignment(assignment);
|
||||
updatedAssignment.insert(extra);
|
||||
|
||||
// Generate index and add to the sparse vector.
|
||||
Eigen::Index idx = 0;
|
||||
size_t prev_cardinality = 1;
|
||||
size_t previousCardinality = 1;
|
||||
// We go in reverse since a DecisionTree has the highest label first
|
||||
for (auto&& it = updated_assignment.rbegin();
|
||||
it != updated_assignment.rend(); it++) {
|
||||
idx += prev_cardinality * it->second;
|
||||
prev_cardinality *= dt.cardinality(it->first);
|
||||
for (auto&& it = updatedAssignment.rbegin();
|
||||
it != updatedAssignment.rend(); it++) {
|
||||
idx += previousCardinality * it->second;
|
||||
previousCardinality *= dt.cardinality(it->first);
|
||||
}
|
||||
sparse_table.coeffRef(idx) = p;
|
||||
sparseTable.coeffRef(idx) = p;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Visit each leaf in `dt` to get the Assignment and leaf value
|
||||
// to populate the sparseTable.
|
||||
dt.visitWith(op);
|
||||
|
||||
return sparse_table;
|
||||
return sparseTable;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||
const DecisionTreeFactor& dtf)
|
||||
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
|
||||
: TableFactor(dkeys, ComputeSparseTable(dkeys, dtf)) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor::TableFactor(const DecisionTreeFactor& dtf)
|
||||
: TableFactor(dtf.discreteKeys(),
|
||||
ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {}
|
||||
ComputeSparseTable(dtf.discreteKeys(), dtf)) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor::TableFactor(const DiscreteConditional& c)
|
||||
|
|
Loading…
Reference in New Issue