diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index de1e1f867..22548da07 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -153,8 +153,7 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, /* ************************************************************************ */ TableFactor::TableFactor(const DecisionTreeFactor& dtf) - : TableFactor(dtf.discreteKeys(), - ComputeSparseTable(dtf.discreteKeys(), dtf)) {} + : TableFactor(dtf.discreteKeys(), dtf) {} /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) @@ -252,12 +251,43 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); - std::vector table; + + // Record key assignment and value pairs in pair_table. + // The assignments are stored in descending order of keys so that the order of + // the values matches what is expected by a DecisionTree. + // This is why we reverse the keys and then + // query for the key value/assignment. + DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend()); + std::vector> pair_table; for (auto i = 0; i < sparse_table_.size(); i++) { - table.push_back(sparse_table_.coeff(i)); + std::stringstream ss; + for (auto&& [key, _] : rdkeys) { + ss << keyValueForIndex(key, i); + } + // k will be in reverse key order already + uint64_t k; + ss >> k; + pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i))); } - // NOTE(Varun): This constructor is really expensive!! - DecisionTreeFactor f(dkeys, table); + + // Sort the pair_table (of assignment-value pairs) based on assignment so we + // get values in reverse key order. + std::sort( + pair_table.begin(), pair_table.end(), + [](const std::pair& a, + const std::pair& b) { return a.first < b.first; }); + + // Create the table vector by extracting the values from pair_table. + // The pair_table has already been sorted in the desired order, + // so the values will be in descending key order. + std::vector table; + std::for_each(pair_table.begin(), pair_table.end(), + [&table](const std::pair& pair) { + table.push_back(pair.second); + }); + + AlgebraicDecisionTree tree(rdkeys, table); + DecisionTreeFactor f(dkeys, tree); return f; } diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index d27c4740c..5ddb4ab43 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -99,7 +99,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { typedef Eigen::SparseVector::InnerIterator SparseIt; typedef std::vector> AssignValList; - public: /// @name Standard Constructors /// @{ @@ -156,6 +155,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @name Standard Interface // /// @{ + /// Getter for the underlying sparse vector + Eigen::SparseVector sparseTable() const { return sparse_table_; } + /// Evaluate probability distribution, is just look up in TableFactor. double evaluate(const Assignment& values) const override;