undo change to TableFactor::toDecisionTreeFactor since it is incorrect for certain larger scale cases

release/4.3a0
Varun Agrawal 2024-12-25 13:27:54 -05:00
parent c2e8867e82
commit 3694f7aeb3
1 changed files with 3 additions and 32 deletions

View File

@ -252,41 +252,12 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys(); DiscreteKeys dkeys = discreteKeys();
// Record key assignment and value pairs in pair_table. std::vector<double> table;
// The assignments are stored in descending order of keys so that the order of
// the values matches what is expected by a DecisionTree.
// This is why we reverse the keys and then
// query for the key value/assignment.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
std::vector<std::pair<uint64_t, double>> pair_table;
for (auto i = 0; i < sparse_table_.size(); i++) { for (auto i = 0; i < sparse_table_.size(); i++) {
std::stringstream ss; table.push_back(sparse_table_.coeff(i));
for (auto&& [key, _] : rdkeys) {
ss << keyValueForIndex(key, i);
}
// k will be in reverse key order already
uint64_t k;
ss >> k;
pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i)));
} }
// Sort the pair_table (of assignment-value pairs) based on assignment so we AlgebraicDecisionTree<Key> tree(dkeys, table);
// get values in reverse key order.
std::sort(
pair_table.begin(), pair_table.end(),
[](const std::pair<uint64_t, double>& a,
const std::pair<uint64_t, double>& b) { return a.first < b.first; });
// Create the table vector by extracting the values from pair_table.
// The pair_table has already been sorted in the desired order,
// so the values will be in descending key order.
std::vector<double> table;
std::for_each(pair_table.begin(), pair_table.end(),
[&table](const std::pair<uint64_t, double>& pair) {
table.push_back(pair.second);
});
AlgebraicDecisionTree<Key> tree(rdkeys, table);
DecisionTreeFactor f(dkeys, tree); DecisionTreeFactor f(dkeys, tree);
return f; return f;
} }