diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp
index 56f1659dc..17731453a 100644
--- a/gtsam/discrete/DecisionTreeFactor.cpp
+++ b/gtsam/discrete/DecisionTreeFactor.cpp
@@ -82,6 +82,22 @@ namespace gtsam {
     ADT::print("", formatter);
   }
 
+  /* ************************************************************************ */
+  DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
+    // apply operand
+    ADT result = ADT::apply(op);
+    // Make a new factor
+    return DecisionTreeFactor(discreteKeys(), result);
+  }
+
+  /* ************************************************************************ */
+  DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
+    // apply operand
+    ADT result = ADT::apply(op);
+    // Make a new factor
+    return DecisionTreeFactor(discreteKeys(), result);
+  }
+
   /* ************************************************************************ */
   DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
                                                ADT::Binary op) const {
@@ -101,14 +117,6 @@ namespace gtsam {
     return DecisionTreeFactor(keys, result);
   }
 
-  /* ************************************************************************ */
-  DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
-    // apply operand
-    ADT result = ADT::apply(op);
-    // Make a new factor
-    return DecisionTreeFactor(discreteKeys(), result);
-  }
-
   /* ************************************************************************ */
   DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
       size_t nrFrontals, ADT::Binary op) const {
@@ -188,10 +196,45 @@ namespace gtsam {
 
   /* ************************************************************************ */
   std::vector<double> DecisionTreeFactor::probabilities() const {
+    // Set of all keys
+    std::set<Key> allKeys(keys().begin(), keys().end());
+
     std::vector<double> probs;
-    for (auto&& [key, value] : enumerate()) {
-      probs.push_back(value);
-    }
+
+    /* An operation that takes each leaf probability, and computes the
+     * nrAssignments by checking the difference between the keys in the factor
+     * and the keys in the assignment.
+     * The nrAssignments is then used to append
+     * the correct number of leaf probability values to the `probs` vector
+     * defined above.
+     */
+    auto op = [&](const Assignment<Key>& a, double p) {
+      // Get all the keys in the current assignment
+      std::set<Key> assignment_keys;
+      for (auto&& [k, _] : a) {
+        assignment_keys.insert(k);
+      }
+
+      // Find the keys missing in the assignment
+      std::vector<Key> diff;
+      std::set_difference(allKeys.begin(), allKeys.end(),
+                          assignment_keys.begin(), assignment_keys.end(),
+                          std::back_inserter(diff));
+
+      // Compute the total number of assignments in the (pruned) subtree
+      size_t nrAssignments = 1;
+      for (auto&& k : diff) {
+        nrAssignments *= cardinalities_.at(k);
+      }
+      // Add p `nrAssignments` times to the probs vector.
+      probs.insert(probs.end(), nrAssignments, p);
+
+      return p;
+    };
+
+    // Go through the tree
+    this->apply(op);
+
     return probs;
   }
 
@@ -305,11 +348,7 @@ namespace gtsam {
     const size_t N = maxNrAssignments;
 
     // Get the probabilities in the decision tree so we can threshold.
-    std::vector<double> probabilities;
-    // NOTE(Varun) this is potentially slow due to the cartesian product
-    for (auto&& [assignment, prob] : this->enumerate()) {
-      probabilities.push_back(prob);
-    }
+    std::vector<double> probabilities = this->probabilities();
 
     // The number of probabilities can be lower than max_leaves
     if (probabilities.size() <= N) {
diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h
index c52f37818..a9a7e5ce0 100644
--- a/gtsam/discrete/DecisionTreeFactor.h
+++ b/gtsam/discrete/DecisionTreeFactor.h
@@ -186,6 +186,13 @@ namespace gtsam {
      * Apply unary operator (*this) "op" f
      * @param op a unary operator that operates on AlgebraicDecisionTree
      */
+    DecisionTreeFactor apply(ADT::Unary op) const;
+
+    /**
+     * Apply unary operator (*this) "op" f
+     * @param op a unary operator that operates on AlgebraicDecisionTree. Takes
+     * both the assignment and the value.
+     */
     DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
 
     /**
diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp
index 2be8e077d..f4e023a4d 100644
--- a/gtsam/discrete/TableFactor.cpp
+++ b/gtsam/discrete/TableFactor.cpp
@@ -56,9 +56,45 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
   sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
 }
 
+/* ************************************************************************ */
+TableFactor::TableFactor(const DiscreteKeys& dkeys,
+                         const DecisionTree<Key, double>& dtree)
+    : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {}
+
+/**
+ * @brief Compute the correct ordering of the leaves in the decision tree.
+ *
+ * This is done by first taking all the values which have modulo 0 value with
+ * the cardinality of the innermost key `n`, and we go up to modulo n.
+ *
+ * @param dt The DecisionTree
+ * @return std::vector<double>
+ */
+std::vector<double> ComputeLeafOrdering(const DiscreteKeys& dkeys,
+                                        const DecisionTreeFactor& dt) {
+  std::vector<double> probs = dt.probabilities();
+  std::vector<double> ordered;
+
+  size_t n = dkeys[0].second;
+
+  for (size_t k = 0; k < n; ++k) {
+    for (size_t idx = 0; idx < probs.size(); ++idx) {
+      if (idx % n == k) {
+        ordered.push_back(probs[idx]);
+      }
+    }
+  }
+  return ordered;
+}
+
+/* ************************************************************************ */
+TableFactor::TableFactor(const DiscreteKeys& dkeys,
+                         const DecisionTreeFactor& dtf)
+    : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
+
 /* ************************************************************************ */
 TableFactor::TableFactor(const DiscreteConditional& c)
-    : TableFactor(c.discreteKeys(), c.probabilities()) {}
+    : TableFactor(c.discreteKeys(), c) {}
 
 /* ************************************************************************ */
 Eigen::SparseVector<double> TableFactor::Convert(
diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h
index 981e1507b..828e794e6 100644
--- a/gtsam/discrete/TableFactor.h
+++ b/gtsam/discrete/TableFactor.h
@@ -144,6 +144,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
   TableFactor(const DiscreteKey& key, const std::vector<double>& row)
       : TableFactor(DiscreteKeys{key}, row) {}
 
+  /// Constructor from DecisionTreeFactor
+  TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);
+
+  /// Constructor from DecisionTree/AlgebraicDecisionTree
+  TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);
+
   /** Construct from a DiscreteConditional type */
   explicit TableFactor(const DiscreteConditional& c);
 
@@ -180,7 +186,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
     return apply(f, Ring::mul);
   };
 
-  /// multiple with DecisionTreeFactor
+  /// multiply with DecisionTreeFactor
   DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
 
   static double safe_div(const double& a, const double& b);
diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp
index e85e4254c..0f7d7a615 100644
--- a/gtsam/discrete/tests/testTableFactor.cpp
+++ b/gtsam/discrete/tests/testTableFactor.cpp
@@ -19,6 +19,7 @@
 #include <CppUnitLite/TestHarness.h>
 #include <gtsam/base/Testable.h>
 #include <gtsam/base/serializationTestHelpers.h>
+#include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteDistribution.h>
 #include <gtsam/discrete/Signature.h>
 #include <gtsam/discrete/TableFactor.h>
@@ -131,6 +132,16 @@ TEST(TableFactor, constructors) {
   // Manually constructed via inspection and comparison to DecisionTreeFactor
   TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
   EXPECT(assert_equal(expected, f4));
+
+  // Test for 9=3x3 values.
+  DiscreteKey V(0, 3), W(1, 3);
+  DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
+  TableFactor f5(conditional5);
+  // GTSAM_PRINT(f5);
+  TableFactor expected_f5(
+      V & W,
+      "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
+  EXPECT(assert_equal(expected_f5, f5, 1e-6));
 }
 
 /* ************************************************************************* */
diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp
index 659d44423..753e35bf0 100644
--- a/gtsam/hybrid/GaussianMixture.cpp
+++ b/gtsam/hybrid/GaussianMixture.cpp
@@ -286,8 +286,6 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
 
 /* *******************************************************************************/
 void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
-  auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
-  auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
   // Functional which loops over all assignments and create a set of
   // GaussianConditionals
   auto pruner = prunerFunc(discreteProbs);
diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp
index b4bf61220..3bafe5a9c 100644
--- a/gtsam/hybrid/HybridBayesNet.cpp
+++ b/gtsam/hybrid/HybridBayesNet.cpp
@@ -129,7 +129,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
 DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
     size_t maxNrLeaves) {
   // Get the joint distribution of only the discrete keys
-  gttic_(HybridBayesNet_PruneDiscreteConditionals);
   // The joint discrete probability.
   DiscreteConditional discreteProbs;
 
@@ -147,12 +146,11 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
       discrete_factor_idxs.push_back(i);
     }
   }
+
   const DecisionTreeFactor prunedDiscreteProbs =
       discreteProbs.prune(maxNrLeaves);
-  gttoc_(HybridBayesNet_PruneDiscreteConditionals);
 
   // Eliminate joint probability back into conditionals
-  gttic_(HybridBayesNet_UpdateDiscreteConditionals);
   DiscreteFactorGraph dfg{prunedDiscreteProbs};
   DiscreteBayesNet::shared_ptr dbn =
       dfg.eliminateSequential(discrete_frontals);
@@ -161,7 +159,6 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
     size_t idx = discrete_factor_idxs.at(i);
     this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
   }
-  gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
 
   return prunedDiscreteProbs;
 }
@@ -180,7 +177,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
 
   HybridBayesNet prunedBayesNetFragment;
 
-  gttic_(HybridBayesNet_PruneMixtures);
   // Go through all the conditionals in the
   // Bayes Net and prune them as per prunedDiscreteProbs.
   for (auto &&conditional : *this) {
@@ -197,7 +193,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
       prunedBayesNetFragment.push_back(conditional);
     }
   }
-  gttoc_(HybridBayesNet_PruneMixtures);
 
   return prunedBayesNetFragment;
 }
diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
index 2d4ac83f6..7a7ca0cbf 100644
--- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp
+++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian(
 // TODO(dellaert): it's probably more efficient to first collect the discrete
 // keys, and then loop over all assignments to populate a vector.
 GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
-  gttic_(assembleGraphTree);
 
   GaussianFactorGraphTree result;
 
@@ -129,8 +128,6 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
     }
   }
 
-  gttoc_(assembleGraphTree);
-
   return result;
 }
diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
index 12506b8af..ad6e64ea0 100644
--- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
+++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp
@@ -420,7 +420,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
   DiscreteFactorGraph discrete_fg;
   // TODO(Varun) Make this a function of HybridGaussianFactorGraph?
   for (auto& factor : (*remainingFactorGraph_partial)) {
-    auto df = dynamic_pointer_cast<DecisionTreeFactor>(factor);
+    auto df = dynamic_pointer_cast<DiscreteFactor>(factor);
     assert(df);
     discrete_fg.push_back(df);
   }
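
For reference, a minimal usage sketch of the API this patch introduces (not part of the diff): it builds a small DecisionTreeFactor, reads back its leaf values via the new probabilities() flattening, and converts it to a TableFactor through the new constructor. The key ids and table values below are illustrative only.

// Usage sketch (illustrative, not part of the patch).
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h>

#include <iostream>

using namespace gtsam;

int main() {
  // Two discrete keys: X has cardinality 2, Y has cardinality 3.
  DiscreteKey X(0, 2), Y(1, 3);

  // A DecisionTreeFactor with one value per assignment (6 leaves).
  DecisionTreeFactor dtf(X & Y, "1 2 3 4 5 6");

  // probabilities() flattens the tree; a pruned subtree contributes its
  // leaf value once per assignment it covers.
  for (double p : dtf.probabilities()) std::cout << p << " ";
  std::cout << "\n";

  // New constructor: build a TableFactor from the DecisionTreeFactor,
  // with the leaves reordered internally via ComputeLeafOrdering.
  TableFactor table(dtf.discreteKeys(), dtf);
  table.print("table");

  return 0;
}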