update ComputeLeafOrdering to give a correct vector of values

release/4.3a0
Varun Agrawal 2024-12-13 09:34:01 -05:00
parent b91c470b69
commit a8e24efdec
1 changed file with 57 additions and 15 deletions

View File

@ -64,27 +64,69 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
/**
 * @brief Compute the index of every positive leaf of the decision tree and
 * store the (index, leaf value) pairs in a sparse vector.
 *
 * Each leaf is visited together with its assignment. Keys of the factor that
 * do not appear in a leaf's assignment are enumerated over all of their
 * values, so the leaf value is written once for every full assignment it
 * covers. Leaves whose value is not strictly positive are skipped.
 *
 * @param dkeys The discrete keys of the factor (not referenced by this
 * implementation; keys and cardinalities are read from `dt`).
 * @param dt The DecisionTree
 * @return Eigen::SparseVector<double>
 */
static Eigen::SparseVector<double> ComputeLeafOrdering(
    const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) {
  // SparseVector needs to know the maximum possible index,
  // so we compute the product of cardinalities.
  size_t prod_cardinality = 1;
  for (auto&& [_, c] : dt.cardinalities()) {
    prod_cardinality *= c;
  }
  Eigen::SparseVector<double> sparse_table(prod_cardinality);

  // Count the positive leaves so that storage can be reserved up front.
  size_t nrValues = 0;
  dt.visit([&nrValues](double x) {
    if (x > 0) nrValues += 1;
  });
  sparse_table.reserve(nrValues);

  const std::set<Key> allKeys(dt.keys().begin(), dt.keys().end());

  // Called for each leaf with the assignment leading to it and its value.
  auto op = [&](const Assignment<Key>& assignment, double p) {
    // Only positive values are stored in the sparse vector.
    if (p <= 0) return;

    // Get all the keys involved in this assignment
    std::set<Key> assignment_keys;
    for (auto&& [k, _] : assignment) {
      assignment_keys.insert(k);
    }

    // Find the keys missing in the assignment
    std::vector<Key> diff;
    std::set_difference(allKeys.begin(), allKeys.end(),
                        assignment_keys.begin(), assignment_keys.end(),
                        std::back_inserter(diff));

    // Generate all assignments using the missing keys
    DiscreteKeys extras;
    for (auto&& key : diff) {
      extras.push_back({key, dt.cardinality(key)});
    }
    const auto extra_assignments = DiscreteValues::CartesianProduct(extras);

    for (auto&& extra : extra_assignments) {
      // Create new assignment using the extra assignment
      DiscreteValues updated_assignment(assignment);
      updated_assignment.insert(extra);

      // Generate index and add to the sparse vector.
      Eigen::Index idx = 0;
      size_t prev_cardinality = 1;
      // We go in reverse since a DecisionTree has the highest label first
      for (auto it = updated_assignment.rbegin();
           it != updated_assignment.rend(); ++it) {
        idx += prev_cardinality * it->second;
        prev_cardinality *= dt.cardinality(it->first);
      }
      sparse_table.coeffRef(idx) = p;
    }
  };

  dt.visitWith(op);
  return sparse_table;
}
/* ************************************************************************ */