From ff3994647a9999b1b2b9e0360c207ea84971f0af Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 25 Jul 2023 11:20:38 -0400 Subject: [PATCH] add new TableFactor constructors --- gtsam/discrete/TableFactor.cpp | 38 +++++++++++++++++++++++- gtsam/discrete/TableFactor.h | 8 ++++- gtsam/discrete/tests/testTableFactor.cpp | 11 +++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 2be8e077d..f4e023a4d 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -56,9 +56,45 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTree& dtree) + : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} + +/** + * @brief Compute the correct ordering of the leaves in the decision tree. + * + * This is done by first taking all the values which have modulo 0 value with + * the cardinality of the innermost key `n`, and we go up to modulo n. + * + * @param dt The DecisionTree + * @return std::vector + */ +std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dt) { + std::vector probs = dt.probabilities(); + std::vector ordered; + + size_t n = dkeys[0].second; + + for (size_t k = 0; k < n; ++k) { + for (size_t idx = 0; idx < probs.size(); ++idx) { + if (idx % n == k) { + ordered.push_back(probs[idx]); + } + } + } + return ordered; +} + +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dtf) + : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} + /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) - : TableFactor(c.discreteKeys(), c.probabilities()) {} + : TableFactor(c.discreteKeys(), c) {} /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 981e1507b..828e794e6 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -144,6 +144,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { TableFactor(const DiscreteKey& key, const std::vector& row) : TableFactor(DiscreteKeys{key}, row) {} + /// Constructor from DecisionTreeFactor + TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf); + + /// Constructor from DecisionTree/AlgebraicDecisionTree + TableFactor(const DiscreteKeys& keys, const DecisionTree& dtree); + /** Construct from a DiscreteConditional type */ explicit TableFactor(const DiscreteConditional& c); @@ -180,7 +186,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, Ring::mul); }; - /// multiple with DecisionTreeFactor + /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; static double safe_div(const double& a, const double& b); diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index e85e4254c..0f7d7a615 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -131,6 +132,16 @@ TEST(TableFactor, constructors) { // Manually constructed via inspection and comparison to DecisionTreeFactor TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); EXPECT(assert_equal(expected, f4)); + + // Test for 9=3x3 values. + DiscreteKey V(0, 3), W(1, 3); + DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11"); + TableFactor f5(conditional5); + // GTSAM_PRINT(f5); + TableFactor expected_f5( + X & Y, + "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"); + EXPECT(assert_equal(expected_f5, f5, 1e-6)); } /* ************************************************************************* */