diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 5fb5ae2e6..04e29024c 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -181,6 +181,15 @@ namespace gtsam { return result; } + /* ************************************************************************ */ + std::vector DecisionTreeFactor::probabilities() const { + std::vector probs; + for (auto&& [key, value] : enumerate()) { + probs.push_back(value); + } + return probs; + } + /* ************************************************************************ */ DiscreteKeys DecisionTreeFactor::discreteKeys() const { DiscreteKeys result; @@ -306,11 +315,10 @@ namespace gtsam { // Get the probabilities in the decision tree so we can threshold. std::vector probabilities; - this->visitLeaf([&](const Leaf& leaf) { - size_t nrAssignments = leaf.nrAssignments(); - double prob = leaf.constant(); - probabilities.insert(probabilities.end(), nrAssignments, prob); - }); + // NOTE(Varun) this is potentially slow due to the cartesian product + for (auto&& [assignment, prob] : this->enumerate()) { + probabilities.push_back(prob); + } // The number of probabilities can be lower than max_leaves if (probabilities.size() <= N) { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 0cfda6b7d..42639095f 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -181,6 +181,9 @@ namespace gtsam { /// Enumerate all values into a map from values to double. std::vector> enumerate() const; + /// Get all the probabilities in order of assignment values + std::vector probabilities() const; + /// Return all the discrete keys associated with this factor. DiscreteKeys discreteKeys() const; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 5fe3cd9d1..c59b7b72c 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -13,11 +13,12 @@ * @file TableFactor.cpp * @brief discrete factor * @date May 4, 2023 - * @author Yoonwoo Kim + * @author Yoonwoo Kim, Varun Agrawal */ #include #include +#include #include #include @@ -56,6 +57,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteConditional& c) + : TableFactor(c.discreteKeys(), c.probabilities()) {} + /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( const std::vector& table) { diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 1462180e0..08c675b67 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -12,7 +12,7 @@ /** * @file TableFactor.h * @date May 4, 2023 - * @author Yoonwoo Kim + * @author Yoonwoo Kim, Varun Agrawal */ #pragma once @@ -32,6 +32,7 @@ namespace gtsam { +class DiscreteConditional; class HybridValues; /** @@ -57,10 +58,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /** * @brief Uses lazy cartesian product to find nth entry in the cartesian - * product of arrays in O(1) - * Example) - * v0 | v1 | val - * 0 | 0 | 10 + * product of arrays in O(1) + * Example) + * v0 | v1 | val + * 0 | 0 | 10 * 0 | 1 | 21 * 1 | 0 | 32 * 1 | 1 | 43 @@ -75,13 +76,13 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { * @brief Return ith key in keys_ as a DiscreteKey * @param i ith key in keys_ * @return DiscreteKey - * */ + */ DiscreteKey discreteKey(size_t i) const { return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); } /// Convert probability table given as doubles to SparseVector. - /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} + /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} static Eigen::SparseVector Convert(const std::vector& table); /// Convert probability table given as string to SparseVector. @@ -142,6 +143,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { TableFactor(const DiscreteKey& key, const std::vector& row) : TableFactor(DiscreteKeys{key}, row) {} + /** Construct from a DiscreteConditional type */ + explicit TableFactor(const DiscreteConditional& c); + /// @} /// @name Testable /// @{ diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 3ad757347..b307d78f6 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -93,7 +93,8 @@ void printTime(map> for (auto&& kv : measured_time) { cout << "dropout: " << kv.first << " | TableFactor time: " << kv.second.first.count() - << " | DecisionTreeFactor time: " << kv.second.second.count() << endl; + << " | DecisionTreeFactor time: " << kv.second.second.count() << + endl; } } @@ -124,6 +125,13 @@ TEST(TableFactor, constructors) { // Assert that error = -log(value) EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); + + // Construct from DiscreteConditional + DiscreteConditional conditional(X | Y = "1/1 2/3 1/4"); + TableFactor f4(conditional); + // Manually constructed via inspection and comparison to DecisionTreeFactor + TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); + EXPECT(assert_equal(expected, f4)); } /* ************************************************************************* */ @@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) { /* ************************************************************************* */ // Benchmark which compares runtime of multiplication of two TableFactors // and two DecisionTreeFactors given sparsity from dense to 90% sparsity. -TEST(TableFactor, benchmark) { +// NOTE: Enable to run. +TEST_DISABLED(TableFactor, benchmark) { DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3), H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);