From 0938159706f1d28e089c936f6a73834baab59367 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 9 Nov 2022 18:38:42 -0500 Subject: [PATCH] overload multifrontal elimination --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 182 +++++++++++++++++++-- gtsam/hybrid/HybridGaussianFactorGraph.h | 26 +++ 2 files changed, 192 insertions(+), 16 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 983817f03..a0c1b67da 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -546,6 +546,24 @@ HybridGaussianFactorGraph::continuousDelta( return delta_tree; } +/* ************************************************************************ */ +DecisionTree +HybridGaussianFactorGraph::continuousDelta( + const DiscreteKeys &discrete_keys, + const boost::shared_ptr &continuousBayesTree, + const std::vector &assignments) const { + // Create a decision tree of all the different VectorValues + std::vector vector_values; + for (const DiscreteValues &assignment : assignments) { + VectorValues values = continuousBayesTree->optimize(assignment); + vector_values.push_back(boost::make_shared(values)); + } + DecisionTree delta_tree(discrete_keys, + vector_values); + + return delta_tree; +} + /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::continuousProbPrimes( const DiscreteKeys &orig_discrete_keys, @@ -584,6 +602,67 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::continuousProbPrimes( return probPrimeTree; } +/* ************************************************************************ */ +AlgebraicDecisionTree HybridGaussianFactorGraph::continuousProbPrimes( + const DiscreteKeys &orig_discrete_keys, + const boost::shared_ptr &continuousBayesTree) const { + // Generate all possible assignments. + const std::vector assignments = + DiscreteValues::CartesianProduct(orig_discrete_keys); + + // Save a copy of the original discrete key ordering + DiscreteKeys discrete_keys(orig_discrete_keys); + // Reverse discrete keys order for correct tree construction + std::reverse(discrete_keys.begin(), discrete_keys.end()); + + // Create a decision tree of all the different VectorValues + DecisionTree delta_tree = + this->continuousDelta(discrete_keys, continuousBayesTree, assignments); + + // Get the probPrime tree with the correct leaf probabilities + std::vector probPrimes; + for (const DiscreteValues &assignment : assignments) { + VectorValues delta = *delta_tree(assignment); + + // If VectorValues is empty, it means this is a pruned branch. + // Set thr probPrime to 0.0. + if (delta.size() == 0) { + probPrimes.push_back(0.0); + continue; + } + + // Compute the error given the delta and the assignment. + double error = this->error(delta, assignment); + probPrimes.push_back(exp(-error)); + } + + AlgebraicDecisionTree probPrimeTree(discrete_keys, probPrimes); + return probPrimeTree; +} + +/* ************************************************************************ */ +std::pair +HybridGaussianFactorGraph::separateContinuousDiscreteOrdering( + const Ordering &ordering) const { + KeySet all_continuous_keys = this->continuousKeys(); + KeySet all_discrete_keys = this->discreteKeys(); + Ordering continuous_ordering, discrete_ordering; + + for (auto &&key : ordering) { + if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(), + key) != all_continuous_keys.end()) { + continuous_ordering.push_back(key); + } else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(), + key) != all_discrete_keys.end()) { + discrete_ordering.push_back(key); + } else { + throw std::runtime_error("Key in ordering not present in factors."); + } + } + + return std::make_pair(continuous_ordering, discrete_ordering); +} + /* ************************************************************************ */ boost::shared_ptr HybridGaussianFactorGraph::eliminateHybridSequential( @@ -640,25 +719,96 @@ boost::shared_ptr HybridGaussianFactorGraph::eliminateSequential( const Ordering &ordering, const Eliminate &function, OptionalVariableIndex variableIndex) const { - KeySet all_continuous_keys = this->continuousKeys(); - KeySet all_discrete_keys = this->discreteKeys(); - Ordering continuous_ordering, discrete_ordering; - // Segregate the continuous and the discrete keys - for (auto &&key : ordering) { - if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(), - key) != all_continuous_keys.end()) { - continuous_ordering.push_back(key); - } else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(), - key) != all_discrete_keys.end()) { - discrete_ordering.push_back(key); - } else { - throw std::runtime_error("Key in ordering not present in factors."); - } + Ordering continuous_ordering, discrete_ordering; + std::tie(continuous_ordering, discrete_ordering) = + this->separateContinuousDiscreteOrdering(ordering); + + return this->eliminateHybridSequential(continuous_ordering, discrete_ordering, + function, variableIndex); +} + +/* ************************************************************************ */ +boost::shared_ptr +HybridGaussianFactorGraph::eliminateHybridMultifrontal( + const boost::optional continuous, + const boost::optional discrete, const Eliminate &function, + OptionalVariableIndex variableIndex) const { + Ordering continuous_ordering = + continuous ? *continuous : Ordering(this->continuousKeys()); + Ordering discrete_ordering = + discrete ? *discrete : Ordering(this->discreteKeys()); + + // Eliminate continuous + HybridBayesTree::shared_ptr bayesTree; + HybridGaussianFactorGraph::shared_ptr discreteGraph; + std::tie(bayesTree, discreteGraph) = + BaseEliminateable::eliminatePartialMultifrontal(continuous_ordering, + function, variableIndex); + + // Get the last continuous conditional which will have all the discrete + Key last_continuous_key = + continuous_ordering.at(continuous_ordering.size() - 1); + auto last_conditional = (*bayesTree)[last_continuous_key]->conditional(); + DiscreteKeys discrete_keys = last_conditional->discreteKeys(); + + // If not discrete variables, return the eliminated bayes net. + if (discrete_keys.size() == 0) { + return bayesTree; } - return this->eliminateHybridSequential(continuous_ordering, - discrete_ordering); + AlgebraicDecisionTree probPrimeTree = + this->continuousProbPrimes(discrete_keys, bayesTree); + + discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + + auto updatedBayesTree = + discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete_ordering, + function); + + auto discrete_clique = (*updatedBayesTree)[discrete_ordering.at(0)]; + + // Set the root of the bayes tree as the discrete clique + for (auto node : bayesTree->nodes()) { + auto clique = node.second; + + if (clique->conditional()->parents() == + discrete_clique->conditional()->frontals()) { + updatedBayesTree->addClique(clique, discrete_clique); + + } else { + // Remove the clique from the children of the parents since it will get + // added again in addClique. + auto clique_it = std::find(clique->parent()->children.begin(), + clique->parent()->children.end(), clique); + clique->parent()->children.erase(clique_it); + updatedBayesTree->addClique(clique, clique->parent()); + } + } + return updatedBayesTree; +} + +/* ************************************************************************ */ +boost::shared_ptr +HybridGaussianFactorGraph::eliminateMultifrontal( + OptionalOrderingType orderingType, const Eliminate &function, + OptionalVariableIndex variableIndex) const { + return BaseEliminateable::eliminateMultifrontal(orderingType, function, + variableIndex); +} + +/* ************************************************************************ */ +boost::shared_ptr +HybridGaussianFactorGraph::eliminateMultifrontal( + const Ordering &ordering, const Eliminate &function, + OptionalVariableIndex variableIndex) const { + // Segregate the continuous and the discrete keys + Ordering continuous_ordering, discrete_ordering; + std::tie(continuous_ordering, discrete_ordering) = + this->separateContinuousDiscreteOrdering(ordering); + + return this->eliminateHybridMultifrontal( + continuous_ordering, discrete_ordering, function, variableIndex); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 88728b6bb..fb8ebbdc4 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -231,6 +231,10 @@ class GTSAM_EXPORT HybridGaussianFactorGraph const DiscreteKeys& discrete_keys, const boost::shared_ptr& continuousBayesNet, const std::vector& assignments) const; + DecisionTree continuousDelta( + const DiscreteKeys& discrete_keys, + const boost::shared_ptr& continuousBayesTree, + const std::vector& assignments) const; /** * @brief Compute the unnormalized probabilities of the continuous variables @@ -244,6 +248,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph AlgebraicDecisionTree continuousProbPrimes( const DiscreteKeys& discrete_keys, const boost::shared_ptr& continuousBayesNet) const; + AlgebraicDecisionTree continuousProbPrimes( + const DiscreteKeys& discrete_keys, + const boost::shared_ptr& continuousBayesTree) const; + + std::pair separateContinuousDiscreteOrdering( + const Ordering& ordering) const; /** * @brief Custom elimination function which computes the correct @@ -269,6 +279,22 @@ class GTSAM_EXPORT HybridGaussianFactorGraph const Eliminate& function = EliminationTraitsType::DefaultEliminate, OptionalVariableIndex variableIndex = boost::none) const; + boost::shared_ptr eliminateHybridMultifrontal( + const boost::optional continuous = boost::none, + const boost::optional discrete = boost::none, + const Eliminate& function = EliminationTraitsType::DefaultEliminate, + OptionalVariableIndex variableIndex = boost::none) const; + + boost::shared_ptr eliminateMultifrontal( + OptionalOrderingType orderingType = boost::none, + const Eliminate& function = EliminationTraitsType::DefaultEliminate, + OptionalVariableIndex variableIndex = boost::none) const; + + boost::shared_ptr eliminateMultifrontal( + const Ordering& ordering, + const Eliminate& function = EliminationTraitsType::DefaultEliminate, + OptionalVariableIndex variableIndex = boost::none) const; + /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys.