diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index de1e1f867..521ca4d47 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -252,12 +252,33 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); - std::vector table; + + DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend()); + std::vector> pair_table; for (auto i = 0; i < sparse_table_.size(); i++) { - table.push_back(sparse_table_.coeff(i)); + std::stringstream ss; + for (auto&& [key, _] : rdkeys) { + ss << keyValueForIndex(key, i); + } + // k will be in reverse key order already + uint64_t k = std::strtoull(ss.str().c_str(), NULL, 10); + pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i))); } - // NOTE(Varun): This constructor is really expensive!! - DecisionTreeFactor f(dkeys, table); + + // Sort based on key so we get values in reverse key order. + std::sort( + pair_table.begin(), pair_table.end(), + [](const std::pair& a, + const std::pair& b) { return a.first <= b.first; }); + + // Create the table vector + std::vector table; + std::for_each(pair_table.begin(), pair_table.end(), + [&table](const std::pair& pair) { + table.push_back(pair.second); + }); + + DecisionTreeFactor f(rdkeys, table); return f; }