From 47f39085bc12d556f547f564b591b33e5b5d4f3c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 3 Jul 2023 16:40:27 -0400 Subject: [PATCH 01/20] small improvements --- gtsam/discrete/DecisionTreeFactor.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 6cce6e5d4..e90a2f96f 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -147,6 +147,9 @@ namespace gtsam { /// @name Advanced Interface /// @{ + /// Inherit all the `apply` methods from AlgebraicDecisionTree + using ADT::apply; + /** * Apply binary operator (*this) "op" f * @param f the second argument for op From baf25de68446370575acbd3b3466d1f55fc40a9c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 8 Jul 2023 13:09:35 -0400 Subject: [PATCH 02/20] initial changes --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index e14023edd..10b8de797 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -274,7 +274,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, DecisionTree probabilities(eliminationResults, probability); return { std::make_shared(gaussianMixture), - std::make_shared(discreteSeparator, probabilities)}; + std::make_shared(discreteSeparator, dtf->probabilities())}; } else { // Otherwise, we create a resulting GaussianMixtureFactor on the separator, // taking care to correct for conditional constant. From 9c88e3ed96184866217b2ecc074604b9a328d3a8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 13 Jul 2023 16:39:55 -0400 Subject: [PATCH 03/20] Use TableFactor in hybrid elimination --- gtsam/hybrid/GaussianMixture.cpp | 6 ++---- gtsam/hybrid/GaussianMixture.h | 5 +++-- gtsam/hybrid/HybridBayesNet.cpp | 2 +- gtsam/hybrid/HybridBayesNet.h | 1 - gtsam/hybrid/HybridBayesTree.cpp | 28 +++++++++++++++------------- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 659d44423..79c385262 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -234,7 +234,7 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { */ std::function &, const GaussianConditional::shared_ptr &)> -GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) { +GaussianMixture::prunerFunc(const TableFactor &discreteProbs) { // Get the discrete keys as sets for the decision tree // and the gaussian mixture. auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); @@ -285,9 +285,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) { } /* *******************************************************************************/ -void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) { - auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); - auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys()); +void GaussianMixture::prune(const TableFactor &discreteProbs) { // Functional which loops over all assignments and create a set of // GaussianConditionals auto pruner = prunerFunc(discreteProbs); diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 0b68fcfd0..c193d065c 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -80,7 +81,7 @@ class GTSAM_EXPORT GaussianMixture */ std::function &, const GaussianConditional::shared_ptr &)> - prunerFunc(const DecisionTreeFactor &discreteProbs); + prunerFunc(const TableFactor &discreteProbs); public: /// @name Constructors @@ -238,7 +239,7 @@ class GTSAM_EXPORT GaussianMixture * * @param discreteProbs A pruned set of probabilities for the discrete keys. */ - void prune(const DecisionTreeFactor &discreteProbs); + void prune(const TableFactor &discreteProbs); /** * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 266e02b0d..302ccfd6c 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -64,7 +64,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { * @return std::function &, double)> */ std::function &, double)> prunerFunc( - const DecisionTreeFactor &prunedDiscreteProbs, + const TableFactor &prunedDiscreteProbs, const HybridConditional &conditional) { // Get the discrete keys as sets for the decision tree // and the Gaussian mixture. diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 23fc4d5d3..a71768bac 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -17,7 +17,6 @@ #pragma once -#include #include #include #include diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index ae8fa0378..d71760b9a 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -175,14 +175,15 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { void HybridBayesTree::prune(const size_t maxNrLeaves) { auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); - DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); - discreteProbs->root_ = prunedDiscreteProbs.root_; + // TODO(Varun) + // TableFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); + // discreteProbs->root_ = prunedDiscreteProbs.root_; /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. - DecisionTreeFactor prunedDiscreteProbs; - HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs, + TableFactor prunedDiscreteProbs; + HybridPrunerData(const TableFactor& prunedDiscreteProbs, const HybridBayesTree::sharedNode& parentClique) : prunedDiscreteProbs(prunedDiscreteProbs) {} @@ -210,15 +211,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { } }; - HybridPrunerData rootData(prunedDiscreteProbs, 0); - { - treeTraversal::no_op visitorPost; - // Limits OpenMP threads since we're mixing TBB and OpenMP - TbbOpenMPMixedScope threadLimiter; - treeTraversal::DepthFirstForestParallel( - *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, - visitorPost); - } + // TODO(Varun) + // HybridPrunerData rootData(prunedDiscreteProbs, 0); + // { + // treeTraversal::no_op visitorPost; + // // Limits OpenMP threads since we're mixing TBB and OpenMP + // TbbOpenMPMixedScope threadLimiter; + // treeTraversal::DepthFirstForestParallel( + // *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, + // visitorPost); + // } } } // namespace gtsam From f238ba55d23d2c7ed80e9baa9ff180f53a070513 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 16 Jul 2023 22:04:09 -0400 Subject: [PATCH 04/20] TableFactor constructor from DecisionTreeFactor and AlgebraicDecisionTree --- gtsam/discrete/TableFactor.cpp | 10 ++++++++++ gtsam/discrete/TableFactor.h | 8 +++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 74eb3ddb3..9a0520a34 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -56,6 +56,16 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dtf) + : TableFactor(dkeys, dtf.probabilities()) {} + +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTree& dtree) + : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} + /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) : TableFactor(c.discreteKeys(), c.probabilities()) {} diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index bd637bb7d..103e14bff 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -141,6 +141,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { TableFactor(const DiscreteKey& key, const std::vector& row) : TableFactor(DiscreteKeys{key}, row) {} + /// Constructor from DecisionTreeFactor + TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf); + + /// Constructor from DecisionTree/AlgebraicDecisionTree + TableFactor(const DiscreteKeys& keys, const DecisionTree& dtree); + /** Construct from a DiscreteConditional type */ explicit TableFactor(const DiscreteConditional& c); @@ -177,7 +183,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, Ring::mul); }; - /// multiple with DecisionTreeFactor + /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; static double safe_div(const double& a, const double& b); From a581788bffa22cfbe307a0471aa9eb31308fc01b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 16 Jul 2023 22:04:48 -0400 Subject: [PATCH 05/20] simplify return --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 10b8de797..dc11cedb0 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -272,9 +272,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors, }; DecisionTree probabilities(eliminationResults, probability); - return { - std::make_shared(gaussianMixture), - std::make_shared(discreteSeparator, dtf->probabilities())}; + + return {std::make_shared(gaussianMixture), + std::make_shared(discreteSeparator, probabilities)}; } else { // Otherwise, we create a resulting GaussianMixtureFactor on the separator, // taking care to correct for conditional constant. From 8462624c572d9d0fd8ac588d60a8a2fdc8a6f9e6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 3 Jul 2023 08:50:36 -0400 Subject: [PATCH 06/20] update HybridFactorGraph wrapper --- gtsam/hybrid/hybrid.i | 1 + 1 file changed, 1 insertion(+) diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 7cc93d4b0..5e0e6dca0 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -179,6 +179,7 @@ class HybridGaussianFactorGraph { void push_back(const gtsam::HybridBayesTree& bayesTree); void push_back(const gtsam::GaussianMixtureFactor* gmm); void push_back(gtsam::DecisionTreeFactor* factor); + void push_back(gtsam::TableFactor* factor); void push_back(gtsam::JacobianFactor* factor); bool empty() const; From c8e9a57cacc8cd5465569f9848e4c1c762e785b6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 11 Jul 2023 21:12:05 -0400 Subject: [PATCH 07/20] unary apply methods for TableFactor --- gtsam/discrete/TableFactor.cpp | 41 +++++++++++++++++++++++- gtsam/discrete/TableFactor.h | 28 ++++++++++++++-- gtsam/discrete/tests/testTableFactor.cpp | 36 +++++++++++++++++++-- 3 files changed, 100 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 9a0520a34..1b265d14f 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -74,7 +74,7 @@ TableFactor::TableFactor(const DiscreteConditional& c) Eigen::SparseVector TableFactor::Convert( const std::vector& table) { Eigen::SparseVector sparse_table(table.size()); - // Count number of nonzero elements in table and reserving the space. + // Count number of nonzero elements in table and reserve the space. const uint64_t nnz = std::count_if(table.begin(), table.end(), [](uint64_t i) { return i != 0; }); sparse_table.reserve(nnz); @@ -228,6 +228,45 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const { cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; } +/* ************************************************************************ */ +TableFactor TableFactor::apply(Unary op) const { + // Initialize new factor. + uint64_t cardi = 1; + for (auto [key, c] : cardinalities_) cardi *= c; + Eigen::SparseVector sparse_table(cardi); + sparse_table.reserve(sparse_table_.nonZeros()); + + // Populate + for (SparseIt it(sparse_table_); it; ++it) { + sparse_table.coeffRef(it.index()) = op(it.value()); + } + + // Free unused memory and return. + sparse_table.pruned(); + sparse_table.data().squeeze(); + return TableFactor(discreteKeys(), sparse_table); +} + +/* ************************************************************************ */ +TableFactor TableFactor::apply(UnaryAssignment op) const { + // Initialize new factor. + uint64_t cardi = 1; + for (auto [key, c] : cardinalities_) cardi *= c; + Eigen::SparseVector sparse_table(cardi); + sparse_table.reserve(sparse_table_.nonZeros()); + + // Populate + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + sparse_table.coeffRef(it.index()) = op(assignment, it.value()); + } + + // Free unused memory and return. + sparse_table.pruned(); + sparse_table.data().squeeze(); + return TableFactor(discreteKeys(), sparse_table); +} + /* ************************************************************************ */ TableFactor TableFactor::apply(const TableFactor& f, Binary op) const { if (keys_.empty() && sparse_table_.nonZeros() == 0) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 103e14bff..828e794e6 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -93,6 +93,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { typedef std::shared_ptr shared_ptr; typedef Eigen::SparseVector::InnerIterator SparseIt; typedef std::vector> AssignValList; + using Unary = std::function; + using UnaryAssignment = + std::function&, const double&)>; using Binary = std::function; public: @@ -224,6 +227,18 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// @name Advanced Interface /// @{ + /** + * Apply unary operator `op(*this)` where `op` accepts the discrete value. + * @param op a unary operator that operates on TableFactor + */ + TableFactor apply(Unary op) const; + /** + * Apply unary operator `op(*this)` where `op` accepts the discrete assignment + * and the value at that assignment. + * @param op a unary operator that operates on TableFactor + */ + TableFactor apply(UnaryAssignment op) const; + /** * Apply binary operator (*this) "op" f * @param f the second argument for op @@ -231,10 +246,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { */ TableFactor apply(const TableFactor& f, Binary op) const; - /// Return keys in contract mode. + /** + * Return keys in contract mode. + * + * Modes are each of the dimensions of a sparse tensor, + * and the contract modes represent which dimensions will + * be involved in contraction (aka tensor multiplication). + */ DiscreteKeys contractDkeys(const TableFactor& f) const; - /// Return keys in free mode. + /** + * @brief Return keys in free mode which are the dimensions + * not involved in the contraction operation. + */ DiscreteKeys freeDkeys(const TableFactor& f) const; /// Return union of DiscreteKeys in two factors. diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index b307d78f6..e85e4254c 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -93,8 +93,7 @@ void printTime(map> for (auto&& kv : measured_time) { cout << "dropout: " << kv.first << " | TableFactor time: " << kv.second.first.count() - << " | DecisionTreeFactor time: " << kv.second.second.count() << - endl; + << " | DecisionTreeFactor time: " << kv.second.second.count() << endl; } } @@ -361,6 +360,39 @@ TEST(TableFactor, htmlWithValueFormatter) { EXPECT(actual == expected); } +/* ************************************************************************* */ +TEST(TableFactor, Unary) { + // Declare a bunch of keys + DiscreteKey X(0, 2), Y(1, 3); + + // Create factors + TableFactor f(X & Y, "2 5 3 6 2 7"); + auto op = [](const double x) { return 2 * x; }; + auto g = f.apply(op); + + TableFactor expected(X & Y, "4 10 6 12 4 14"); + EXPECT(assert_equal(g, expected)); + + auto sq_op = [](const double x) { return x * x; }; + auto g_sq = f.apply(sq_op); + TableFactor expected_sq(X & Y, "4 25 9 36 4 49"); + EXPECT(assert_equal(g_sq, expected_sq)); +} + +/* ************************************************************************* */ +TEST(TableFactor, UnaryAssignment) { + // Declare a bunch of keys + DiscreteKey X(0, 2), Y(1, 3); + + // Create factors + TableFactor f(X & Y, "2 5 3 6 2 7"); + auto op = [](const Assignment& key, const double x) { return 2 * x; }; + auto g = f.apply(op); + + TableFactor expected(X & Y, "4 10 6 12 4 14"); + EXPECT(assert_equal(g, expected)); +} + /* ************************************************************************* */ int main() { TestResult tr; From 2b85cfedd4845e3a4329b9cf6dc0a5acb84fc499 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 22 Jul 2023 00:50:33 -0400 Subject: [PATCH 08/20] DecisionTreeFactor apply methods --- gtsam/discrete/DecisionTreeFactor.cpp | 16 ++++++++++++++++ gtsam/discrete/DecisionTreeFactor.h | 14 ++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ff18268b1..4aa5e5759 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -82,6 +82,22 @@ namespace gtsam { ADT::print("", formatter); } + /* ************************************************************************ */ + DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const { + // apply operand + ADT result = ADT::apply(op); + // Make a new factor + return DecisionTreeFactor(discreteKeys(), result); + } + + /* ************************************************************************ */ + DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const { + // apply operand + ADT result = ADT::apply(op); + // Make a new factor + return DecisionTreeFactor(discreteKeys(), result); + } + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, ADT::Binary op) const { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index e90a2f96f..77da93c5f 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -147,8 +147,18 @@ namespace gtsam { /// @name Advanced Interface /// @{ - /// Inherit all the `apply` methods from AlgebraicDecisionTree - using ADT::apply; + /** + * Apply unary operator (*this) "op" f + * @param op a unary operator that operates on AlgebraicDecisionTree + */ + DecisionTreeFactor apply(ADT::Unary op) const; + + /** + * Apply unary operator (*this) "op" f + * @param op a unary operator that operates on AlgebraicDecisionTree. Takes + * both the assignment and the value. + */ + DecisionTreeFactor apply(ADT::UnaryAssignment op) const; /** * Apply binary operator (*this) "op" f From 3d24d0128f71e51f889ff5b2790eb2bd79bf7a73 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 22 Jul 2023 00:51:34 -0400 Subject: [PATCH 09/20] efficient probabilities method --- gtsam/discrete/DecisionTreeFactor.cpp | 40 ++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 4aa5e5759..19cc8e230 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -197,9 +197,39 @@ namespace gtsam { /* ************************************************************************ */ std::vector DecisionTreeFactor::probabilities() const { std::vector probs; - for (auto&& [key, value] : enumerate()) { - probs.push_back(value); + + // Get all the key cardinalities + std::map cardins; + for (auto [key, cardinality] : discreteKeys()) { + cardins[key] = cardinality; } + // Set of all keys + std::set allKeys(keys().begin(), keys().end()); + + // Go through the tree + std::vector ys; + this->apply([&](const Assignment a, double p) { + // Get all the keys in the current assignment + std::set assignment_keys; + for (auto&& [k, _] : a) { + assignment_keys.insert(k); + } + + // Find the keys missing in the assignment + std::vector diff; + std::set_difference(allKeys.begin(), allKeys.end(), + assignment_keys.begin(), assignment_keys.end(), + std::back_inserter(diff)); + + // Compute the total number of assignments in the (pruned) subtree + size_t nrAssignments = 1; + for (auto&& k : diff) { + nrAssignments *= cardins.at(k); + } + probs.insert(probs.end(), nrAssignments, p); + return p; + }); + return probs; } @@ -313,11 +343,7 @@ namespace gtsam { const size_t N = maxNrAssignments; // Get the probabilities in the decision tree so we can threshold. - std::vector probabilities; - // NOTE(Varun) this is potentially slow due to the cartesian product - for (auto&& [assignment, prob] : this->enumerate()) { - probabilities.push_back(prob); - } + std::vector probabilities = this->probabilities(); // The number of probabilities can be lower than max_leaves if (probabilities.size() <= N) { From 5f83464f0d4dafd0c1c780852ebfdd47a55e4322 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 22 Jul 2023 01:09:55 -0400 Subject: [PATCH 10/20] use existing cardinalities --- gtsam/discrete/DecisionTreeFactor.cpp | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 19cc8e230..7ba4c9011 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -196,18 +196,12 @@ namespace gtsam { /* ************************************************************************ */ std::vector DecisionTreeFactor::probabilities() const { - std::vector probs; - - // Get all the key cardinalities - std::map cardins; - for (auto [key, cardinality] : discreteKeys()) { - cardins[key] = cardinality; - } // Set of all keys std::set allKeys(keys().begin(), keys().end()); + std::vector probs; + // Go through the tree - std::vector ys; this->apply([&](const Assignment a, double p) { // Get all the keys in the current assignment std::set assignment_keys; @@ -224,7 +218,7 @@ namespace gtsam { // Compute the total number of assignments in the (pruned) subtree size_t nrAssignments = 1; for (auto&& k : diff) { - nrAssignments *= cardins.at(k); + nrAssignments *= cardinalities_.at(k); } probs.insert(probs.end(), nrAssignments, p); return p; From ad84163f663cad6081c399f2735ac36ab7ffce79 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 23 Jul 2023 12:13:07 -0400 Subject: [PATCH 11/20] use discrete base class in getting discrete factors --- gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index af3a23b94..13e7e2c8f 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -420,7 +420,7 @@ TEST(HybridFactorGraph, Full_Elimination) { DiscreteFactorGraph discrete_fg; // TODO(Varun) Make this a function of HybridGaussianFactorGraph? for (auto& factor : (*remainingFactorGraph_partial)) { - auto df = dynamic_pointer_cast(factor); + auto df = dynamic_pointer_cast(factor); assert(df); discrete_fg.push_back(df); } From 52f26e3e97befe2e784509eeec07453bd4fc8ff0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 23 Jul 2023 16:32:52 -0400 Subject: [PATCH 12/20] update TableFactor to use new version of DT probabilities --- gtsam/discrete/TableFactor.cpp | 38 ++++++++++++++++++++---- gtsam/discrete/tests/testTableFactor.cpp | 11 +++++++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 1b265d14f..f4e023a4d 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -56,19 +56,45 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } -/* ************************************************************************ */ -TableFactor::TableFactor(const DiscreteKeys& dkeys, - const DecisionTreeFactor& dtf) - : TableFactor(dkeys, dtf.probabilities()) {} - /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteKeys& dkeys, const DecisionTree& dtree) : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} +/** + * @brief Compute the correct ordering of the leaves in the decision tree. + * + * This is done by first taking all the values which have modulo 0 value with + * the cardinality of the innermost key `n`, and we go up to modulo n. + * + * @param dt The DecisionTree + * @return std::vector + */ +std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dt) { + std::vector probs = dt.probabilities(); + std::vector ordered; + + size_t n = dkeys[0].second; + + for (size_t k = 0; k < n; ++k) { + for (size_t idx = 0; idx < probs.size(); ++idx) { + if (idx % n == k) { + ordered.push_back(probs[idx]); + } + } + } + return ordered; +} + +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dtf) + : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} + /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) - : TableFactor(c.discreteKeys(), c.probabilities()) {} + : TableFactor(c.discreteKeys(), c) {} /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index e85e4254c..0f7d7a615 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -131,6 +132,16 @@ TEST(TableFactor, constructors) { // Manually constructed via inspection and comparison to DecisionTreeFactor TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); EXPECT(assert_equal(expected, f4)); + + // Test for 9=3x3 values. + DiscreteKey V(0, 3), W(1, 3); + DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11"); + TableFactor f5(conditional5); + // GTSAM_PRINT(f5); + TableFactor expected_f5( + X & Y, + "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"); + EXPECT(assert_equal(expected_f5, f5, 1e-6)); } /* ************************************************************************* */ From 2df3cc80a90c860f911d213f6be939e0c97cc673 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 23 Jul 2023 17:09:51 -0400 Subject: [PATCH 13/20] undo previous changes --- gtsam/hybrid/HybridBayesTree.cpp | 28 ++++++++++------------ gtsam/hybrid/HybridGaussianFactorGraph.cpp | 5 ++-- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index d71760b9a..ae8fa0378 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -175,15 +175,14 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { void HybridBayesTree::prune(const size_t maxNrLeaves) { auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); - // TODO(Varun) - // TableFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); - // discreteProbs->root_ = prunedDiscreteProbs.root_; + DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); + discreteProbs->root_ = prunedDiscreteProbs.root_; /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. - TableFactor prunedDiscreteProbs; - HybridPrunerData(const TableFactor& prunedDiscreteProbs, + DecisionTreeFactor prunedDiscreteProbs; + HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs, const HybridBayesTree::sharedNode& parentClique) : prunedDiscreteProbs(prunedDiscreteProbs) {} @@ -211,16 +210,15 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { } }; - // TODO(Varun) - // HybridPrunerData rootData(prunedDiscreteProbs, 0); - // { - // treeTraversal::no_op visitorPost; - // // Limits OpenMP threads since we're mixing TBB and OpenMP - // TbbOpenMPMixedScope threadLimiter; - // treeTraversal::DepthFirstForestParallel( - // *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, - // visitorPost); - // } + HybridPrunerData rootData(prunedDiscreteProbs, 0); + { + treeTraversal::no_op visitorPost; + // Limits OpenMP threads since we're mixing TBB and OpenMP + TbbOpenMPMixedScope threadLimiter; + treeTraversal::DepthFirstForestParallel( + *this, rootData, HybridPrunerData::AssignmentPreOrderVisitor, + visitorPost); + } } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index dc11cedb0..7ff87c347 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -273,8 +273,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors, DecisionTree probabilities(eliminationResults, probability); - return {std::make_shared(gaussianMixture), - std::make_shared(discreteSeparator, probabilities)}; + return { + std::make_shared(gaussianMixture), + std::make_shared(discreteSeparator, probabilities)}; } else { // Otherwise, we create a resulting GaussianMixtureFactor on the separator, // taking care to correct for conditional constant. From a4462a0a3eeb729502042fdcb61b5f4a89f3422a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 23 Jul 2023 17:12:27 -0400 Subject: [PATCH 14/20] undo some more --- gtsam/hybrid/GaussianMixture.cpp | 4 ++-- gtsam/hybrid/GaussianMixture.h | 5 ++--- gtsam/hybrid/HybridBayesNet.cpp | 2 +- gtsam/hybrid/HybridBayesNet.h | 4 +++- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 79c385262..753e35bf0 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -234,7 +234,7 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { */ std::function &, const GaussianConditional::shared_ptr &)> -GaussianMixture::prunerFunc(const TableFactor &discreteProbs) { +GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) { // Get the discrete keys as sets for the decision tree // and the gaussian mixture. auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys()); @@ -285,7 +285,7 @@ GaussianMixture::prunerFunc(const TableFactor &discreteProbs) { } /* *******************************************************************************/ -void GaussianMixture::prune(const TableFactor &discreteProbs) { +void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) { // Functional which loops over all assignments and create a set of // GaussianConditionals auto pruner = prunerFunc(discreteProbs); diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index c193d065c..0b68fcfd0 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -81,7 +80,7 @@ class GTSAM_EXPORT GaussianMixture */ std::function &, const GaussianConditional::shared_ptr &)> - prunerFunc(const TableFactor &discreteProbs); + prunerFunc(const DecisionTreeFactor &discreteProbs); public: /// @name Constructors @@ -239,7 +238,7 @@ class GTSAM_EXPORT GaussianMixture * * @param discreteProbs A pruned set of probabilities for the discrete keys. */ - void prune(const TableFactor &discreteProbs); + void prune(const DecisionTreeFactor &discreteProbs); /** * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 302ccfd6c..266e02b0d 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -64,7 +64,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { * @return std::function &, double)> */ std::function &, double)> prunerFunc( - const TableFactor &prunedDiscreteProbs, + const DecisionTreeFactor &prunedDiscreteProbs, const HybridConditional &conditional) { // Get the discrete keys as sets for the decision tree // and the Gaussian mixture. diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index a71768bac..825f3acfd 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -225,7 +226,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * * @param prunedDiscreteProbs */ - void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs); + void updateDiscreteConditionals( + const DecisionTreeFactor &prunedDiscreteProbs); #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ From 62d020a53156948367dda1538473df7bfb2dff24 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 23 Jul 2023 17:36:36 -0400 Subject: [PATCH 15/20] remove duplicate definition --- gtsam/discrete/DecisionTreeFactor.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index f0472245b..7ba4c9011 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -117,14 +117,6 @@ namespace gtsam { return DecisionTreeFactor(keys, result); } - /* ************************************************************************ */ - DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const { - // apply operand - ADT result = ADT::apply(op); - // Make a new factor - return DecisionTreeFactor(discreteKeys(), result); - } - /* ************************************************************************ */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( size_t nrFrontals, ADT::Binary op) const { From df0c5d7ca0994260b01b851583141b4a53e13a74 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 24 Jul 2023 19:23:16 -0400 Subject: [PATCH 16/20] remove timers --- gtsam/hybrid/HybridBayesNet.cpp | 7 +------ gtsam/hybrid/HybridGaussianFactorGraph.cpp | 3 --- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index b4bf61220..3bafe5a9c 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -129,7 +129,6 @@ std::function &, double)> prunerFunc( DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( size_t maxNrLeaves) { // Get the joint distribution of only the discrete keys - gttic_(HybridBayesNet_PruneDiscreteConditionals); // The joint discrete probability. DiscreteConditional discreteProbs; @@ -147,12 +146,11 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( discrete_factor_idxs.push_back(i); } } + const DecisionTreeFactor prunedDiscreteProbs = discreteProbs.prune(maxNrLeaves); - gttoc_(HybridBayesNet_PruneDiscreteConditionals); // Eliminate joint probability back into conditionals - gttic_(HybridBayesNet_UpdateDiscreteConditionals); DiscreteFactorGraph dfg{prunedDiscreteProbs}; DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals); @@ -161,7 +159,6 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( size_t idx = discrete_factor_idxs.at(i); this->at(idx) = std::make_shared(dbn->at(i)); } - gttoc_(HybridBayesNet_UpdateDiscreteConditionals); return prunedDiscreteProbs; } @@ -180,7 +177,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { HybridBayesNet prunedBayesNetFragment; - gttic_(HybridBayesNet_PruneMixtures); // Go through all the conditionals in the // Bayes Net and prune them as per prunedDiscreteProbs. for (auto &&conditional : *this) { @@ -197,7 +193,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { prunedBayesNetFragment.push_back(conditional); } } - gttoc_(HybridBayesNet_PruneMixtures); return prunedBayesNetFragment; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 2d4ac83f6..7a7ca0cbf 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian( // TODO(dellaert): it's probably more efficient to first collect the discrete // keys, and then loop over all assignments to populate a vector. GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { - gttic_(assembleGraphTree); GaussianFactorGraphTree result; @@ -129,8 +128,6 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { } } - gttoc_(assembleGraphTree); - return result; } From cb3c35b81a8b7684d5f1a07a05aba5f5f00608ac Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 25 Jul 2023 11:11:55 -0400 Subject: [PATCH 17/20] refactor and better document prune method --- gtsam/discrete/DecisionTreeFactor.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 7ba4c9011..17731453a 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -201,8 +201,14 @@ namespace gtsam { std::vector probs; - // Go through the tree - this->apply([&](const Assignment a, double p) { + /* An operation that takes each leaf probability, and computes the + * nrAssignments by checking the difference between the keys in the factor + * and the keys in the assignment. + * The nrAssignments is then used to append + * the correct number of leaf probability values to the `probs` vector + * defined above. + */ + auto op = [&](const Assignment& a, double p) { // Get all the keys in the current assignment std::set assignment_keys; for (auto&& [k, _] : a) { @@ -220,9 +226,14 @@ namespace gtsam { for (auto&& k : diff) { nrAssignments *= cardinalities_.at(k); } + // Add p `nrAssignments` times to the probs vector. probs.insert(probs.end(), nrAssignments, p); + return p; - }); + }; + + // Go through the tree + this->apply(op); return probs; } From 3a78499d36dc31a8512564ecf6f5e05f10beae58 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 25 Jul 2023 11:12:29 -0400 Subject: [PATCH 18/20] undo TableFactor changes --- gtsam/discrete/TableFactor.cpp | 38 +----------------------- gtsam/discrete/TableFactor.h | 8 +---- gtsam/discrete/tests/testTableFactor.cpp | 10 ------- 3 files changed, 2 insertions(+), 54 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index f4e023a4d..2be8e077d 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -56,45 +56,9 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } -/* ************************************************************************ */ -TableFactor::TableFactor(const DiscreteKeys& dkeys, - const DecisionTree& dtree) - : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} - -/** - * @brief Compute the correct ordering of the leaves in the decision tree. - * - * This is done by first taking all the values which have modulo 0 value with - * the cardinality of the innermost key `n`, and we go up to modulo n. - * - * @param dt The DecisionTree - * @return std::vector - */ -std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys, - const DecisionTreeFactor& dt) { - std::vector probs = dt.probabilities(); - std::vector ordered; - - size_t n = dkeys[0].second; - - for (size_t k = 0; k < n; ++k) { - for (size_t idx = 0; idx < probs.size(); ++idx) { - if (idx % n == k) { - ordered.push_back(probs[idx]); - } - } - } - return ordered; -} - -/* ************************************************************************ */ -TableFactor::TableFactor(const DiscreteKeys& dkeys, - const DecisionTreeFactor& dtf) - : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} - /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) - : TableFactor(c.discreteKeys(), c) {} + : TableFactor(c.discreteKeys(), c.probabilities()) {} /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 828e794e6..d5decd1a1 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -144,12 +144,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { TableFactor(const DiscreteKey& key, const std::vector& row) : TableFactor(DiscreteKeys{key}, row) {} - /// Constructor from DecisionTreeFactor - TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf); - - /// Constructor from DecisionTree/AlgebraicDecisionTree - TableFactor(const DiscreteKeys& keys, const DecisionTree& dtree); - /** Construct from a DiscreteConditional type */ explicit TableFactor(const DiscreteConditional& c); @@ -181,7 +175,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const; - /// multiply two TableFactors + /// multiple two TableFactors TableFactor operator*(const TableFactor& f) const { return apply(f, Ring::mul); }; diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 0f7d7a615..7ad0ac1ff 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -132,16 +132,6 @@ TEST(TableFactor, constructors) { // Manually constructed via inspection and comparison to DecisionTreeFactor TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); EXPECT(assert_equal(expected, f4)); - - // Test for 9=3x3 values. - DiscreteKey V(0, 3), W(1, 3); - DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11"); - TableFactor f5(conditional5); - // GTSAM_PRINT(f5); - TableFactor expected_f5( - X & Y, - "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"); - EXPECT(assert_equal(expected_f5, f5, 1e-6)); } /* ************************************************************************* */ From 8c9fad8cb16db195114a064a1a0e1c0d30002a93 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 25 Jul 2023 11:13:25 -0400 Subject: [PATCH 19/20] undo more changes in TableFactor --- gtsam/discrete/TableFactor.h | 4 ++-- gtsam/discrete/tests/testTableFactor.cpp | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index d5decd1a1..981e1507b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -175,12 +175,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const; - /// multiple two TableFactors + /// multiply two TableFactors TableFactor operator*(const TableFactor& f) const { return apply(f, Ring::mul); }; - /// multiply with DecisionTreeFactor + /// multiple with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; static double safe_div(const double& a, const double& b); diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index 7ad0ac1ff..e85e4254c 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include From ff3994647a9999b1b2b9e0360c207ea84971f0af Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 25 Jul 2023 11:20:38 -0400 Subject: [PATCH 20/20] add new TableFactor constructors --- gtsam/discrete/TableFactor.cpp | 38 +++++++++++++++++++++++- gtsam/discrete/TableFactor.h | 8 ++++- gtsam/discrete/tests/testTableFactor.cpp | 11 +++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 2be8e077d..f4e023a4d 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -56,9 +56,45 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys, sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); } +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTree& dtree) + : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {} + +/** + * @brief Compute the correct ordering of the leaves in the decision tree. + * + * This is done by first taking all the values which have modulo 0 value with + * the cardinality of the innermost key `n`, and we go up to modulo n. + * + * @param dt The DecisionTree + * @return std::vector + */ +std::vector ComputeLeafOrdering(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dt) { + std::vector probs = dt.probabilities(); + std::vector ordered; + + size_t n = dkeys[0].second; + + for (size_t k = 0; k < n; ++k) { + for (size_t idx = 0; idx < probs.size(); ++idx) { + if (idx % n == k) { + ordered.push_back(probs[idx]); + } + } + } + return ordered; +} + +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const DecisionTreeFactor& dtf) + : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {} + /* ************************************************************************ */ TableFactor::TableFactor(const DiscreteConditional& c) - : TableFactor(c.discreteKeys(), c.probabilities()) {} + : TableFactor(c.discreteKeys(), c) {} /* ************************************************************************ */ Eigen::SparseVector TableFactor::Convert( diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 981e1507b..828e794e6 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -144,6 +144,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { TableFactor(const DiscreteKey& key, const std::vector& row) : TableFactor(DiscreteKeys{key}, row) {} + /// Constructor from DecisionTreeFactor + TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf); + + /// Constructor from DecisionTree/AlgebraicDecisionTree + TableFactor(const DiscreteKeys& keys, const DecisionTree& dtree); + /** Construct from a DiscreteConditional type */ explicit TableFactor(const DiscreteConditional& c); @@ -180,7 +186,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, Ring::mul); }; - /// multiple with DecisionTreeFactor + /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; static double safe_div(const double& a, const double& b); diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index e85e4254c..0f7d7a615 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -131,6 +132,16 @@ TEST(TableFactor, constructors) { // Manually constructed via inspection and comparison to DecisionTreeFactor TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); EXPECT(assert_equal(expected, f4)); + + // Test for 9=3x3 values. + DiscreteKey V(0, 3), W(1, 3); + DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11"); + TableFactor f5(conditional5); + // GTSAM_PRINT(f5); + TableFactor expected_f5( + X & Y, + "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667"); + EXPECT(assert_equal(expected_f5, f5, 1e-6)); } /* ************************************************************************* */