diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 1b265d14f..f4e023a4d 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -56,19 +56,45 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } -/* ************************************************************************ */ -TableFactor::TableFactor(const DiscreteKeys& dkeys, - const DecisionTreeFactor& dtf) - : TableFactor(dkeys, dtf.probabilities()) {} - /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteKeys& dkeys, const DecisionTree& dtree) : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} +/** + * @brief Compute the correct ordering of the leaves in the decision tree. + * + * This is done by first taking all the values which have modulo 0 value with + * the cardinality of the innermost key `n`, and we go up to modulo n. + * + * @param dt The DecisionTree + * @return std::vector + */ +std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dt) { + std::vector probs = dt.probabilities(); + std::vector ordered; + + size_t n = dkeys[0].second; + + for (size_t k = 0; k < n; ++k) { + for (size_t idx = 0; idx < probs.size(); ++idx) { + if (idx % n == k) { + ordered.push_back(probs[idx]); + } + } + } + return ordered; +} + +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dtf) + : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} + /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) - : TableFactor(c.discreteKeys(), c.probabilities()) {} + : TableFactor(c.discreteKeys(), c) {} /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index e85e4254c..0f7d7a615 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -131,6 +132,16 @@ TEST(TableFactor, constructors) { // Manually constructed via inspection and comparison to DecisionTreeFactor TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); EXPECT(assert_equal(expected, f4)); + + // Test for 9=3x3 values. + DiscreteKey V(0, 3), W(1, 3); + DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11"); + TableFactor f5(conditional5); + // GTSAM_PRINT(f5); + TableFactor expected_f5( + X & Y, + "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"); + EXPECT(assert_equal(expected_f5, f5, 1e-6)); } /* ************************************************************************* */