update toDecisionTreeFactor to use reverse key format so it's faster

release/4.3a0
Varun Agrawal 2024-12-15 12:30:11 -05:00
parent 039c9b1542
commit 293c29ebf8
1 changed files with 25 additions and 4 deletions

View File

@ -252,12 +252,33 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
std::vector<double> table;
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
std::vector<std::pair<uint64_t, double>> pair_table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
std::stringstream ss;
for (auto&& [key, _] : rdkeys) {
ss << keyValueForIndex(key, i);
}
// k will be in reverse key order already
uint64_t k = std::strtoull(ss.str().c_str(), NULL, 10);
pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i)));
}
// NOTE(Varun): This constructor is really expensive!!
DecisionTreeFactor f(dkeys, table);
// Sort based on key so we get values in reverse key order.
std::sort(
pair_table.begin(), pair_table.end(),
[](const std::pair<uint64_t, double>& a,
const std::pair<uint64_t, double>& b) { return a.first <= b.first; });
// Create the table vector
std::vector<double> table;
std::for_each(pair_table.begin(), pair_table.end(),
[&table](const std::pair<uint64_t, double>& pair) {
table.push_back(pair.second);
});
DecisionTreeFactor f(rdkeys, table);
return f;
}