From cd3cfa0faa5ebcacc05d7ccbdfc21bcdf505d0f9 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 3 Dec 2022 17:14:11 +0530 Subject: [PATCH] moved sequential elimination code to HybridEliminationTree --- gtsam/hybrid/HybridEliminationTree.cpp | 12 ++- gtsam/hybrid/HybridEliminationTree.h | 107 +++++++++++++++++++++ gtsam/hybrid/HybridFactor.cpp | 2 +- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 68 ------------- gtsam/hybrid/HybridGaussianFactorGraph.h | 26 ----- gtsam/hybrid/HybridSmoother.cpp | 2 +- 6 files changed, 119 insertions(+), 98 deletions(-) diff --git a/gtsam/hybrid/HybridEliminationTree.cpp b/gtsam/hybrid/HybridEliminationTree.cpp index c2df2dd60..fe9190571 100644 --- a/gtsam/hybrid/HybridEliminationTree.cpp +++ b/gtsam/hybrid/HybridEliminationTree.cpp @@ -27,12 +27,20 @@ template class EliminationTree; HybridEliminationTree::HybridEliminationTree( const HybridGaussianFactorGraph& factorGraph, const VariableIndex& structure, const Ordering& order) - : Base(factorGraph, structure, order) {} + : Base(factorGraph, structure, order), + graph_(factorGraph), + variable_index_(structure) { + // Segregate the continuous and the discrete keys + std::tie(continuous_ordering_, discrete_ordering_) = + graph_.separateContinuousDiscreteOrdering(order); +} /* ************************************************************************* */ HybridEliminationTree::HybridEliminationTree( const HybridGaussianFactorGraph& factorGraph, const Ordering& order) - : Base(factorGraph, order) {} + : Base(factorGraph, order), + graph_(factorGraph), + variable_index_(VariableIndex(factorGraph)) {} /* ************************************************************************* */ bool HybridEliminationTree::equals(const This& other, double tol) const { diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h index b2dd1bd9c..dfa88bf4e 100644 --- a/gtsam/hybrid/HybridEliminationTree.h +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -33,6 +33,12 @@ class GTSAM_EXPORT HybridEliminationTree private: friend class ::EliminationTreeTester; + Ordering continuous_ordering_, discrete_ordering_; + /// Used to store the original factor graph to eliminate + HybridGaussianFactorGraph graph_; + /// Store the provided variable index. + VariableIndex variable_index_; + public: typedef EliminationTree Base; ///< Base class @@ -66,6 +72,107 @@ class GTSAM_EXPORT HybridEliminationTree /** Test whether the tree is equal to another */ bool equals(const This& other, double tol = 1e-9) const; + + /** + * @brief Helper method to eliminate continuous variables. + * + * If no continuous variables exist, return an empty bayes net + * and the original graph. + * + * @param function Elimination function for hybrid elimination. + * @return std::pair, + * boost::shared_ptr > + */ + std::pair, boost::shared_ptr> + eliminateContinuous(Eliminate function) const { + if (continuous_ordering_.size() > 0) { + This continuous_etree(graph_, variable_index_, continuous_ordering_); + return continuous_etree.Base::eliminate(function); + + } else { + BayesNetType::shared_ptr bayesNet = boost::make_shared(); + FactorGraphType::shared_ptr discreteGraph = + boost::make_shared(graph_); + return std::make_pair(bayesNet, discreteGraph); + } + } + + /** + * @brief Helper method to eliminate the discrete variables after the + * continuous variables have been eliminated. + * + * If there are no discrete variables, return an empty bayes net and the + * discreteGraph which is passed in. + * + * @param function Elimination function + * @param discreteGraph The factor graph with the factor ϕ(X | M, Z). + * @return std::pair, + * boost::shared_ptr > + */ + std::pair, boost::shared_ptr> + eliminateDiscrete(Eliminate function, + const FactorGraphType::shared_ptr& discreteGraph) const { + BayesNetType::shared_ptr discreteBayesNet; + FactorGraphType::shared_ptr finalGraph; + if (discrete_ordering_.size() > 0) { + This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph), + discrete_ordering_); + + std::tie(discreteBayesNet, finalGraph) = + discrete_etree.Base::eliminate(function); + + } else { + discreteBayesNet = boost::make_shared(); + finalGraph = discreteGraph; + } + + return std::make_pair(discreteBayesNet, finalGraph); + } + + /** + * @brief Override the EliminationTree eliminate method + * so we can perform hybrid elimination correctly. + * + * @param function + * @return std::pair, + * boost::shared_ptr > + */ + std::pair, boost::shared_ptr> + eliminate(Eliminate function) const { + // Perform continuous elimination + BayesNetType::shared_ptr bayesNet; + FactorGraphType::shared_ptr discreteGraph; + std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function); + + // If we have eliminated continuous variables + // and have discrete variables to eliminate, + // then compute P(X | M, Z) + if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) { + // Get the last continuous conditional + // which will have all the discrete keys + HybridConditional::shared_ptr last_conditional = + bayesNet->at(bayesNet->size() - 1); + DiscreteKeys discrete_keys = last_conditional->discreteKeys(); + + // DecisionTree for P'(X|M, Z) for all mode sequences M + const AlgebraicDecisionTree probPrimeTree = + graph_.continuousProbPrimes(discrete_keys, bayesNet); + + // Add the model selection factor P(M|Z) + discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + } + + // Perform discrete elimination + BayesNetType::shared_ptr discreteBayesNet; + FactorGraphType::shared_ptr finalGraph; + std::tie(discreteBayesNet, finalGraph) = + eliminateDiscrete(function, discreteGraph); + + // Add the discrete conditionals to the hybrid conditionals + bayesNet->add(*discreteBayesNet); + + return std::make_pair(bayesNet, finalGraph); + } }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index b25e97f05..1216fd922 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -81,7 +81,7 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { /* ************************************************************************ */ void HybridFactor::print(const std::string &s, const KeyFormatter &formatter) const { - std::cout << (s.empty() ? "" : s + "\n"); + std::cout << s; if (isContinuous_) std::cout << "Continuous "; if (isDiscrete_) std::cout << "Discrete "; if (isHybrid_) std::cout << "Hybrid "; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 1d52a24af..1afe4f38a 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -550,74 +550,6 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering( return std::make_pair(continuous_ordering, discrete_ordering); } -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateHybridSequential( - const boost::optional continuous, - const boost::optional discrete, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - const Ordering continuous_ordering = - continuous ? *continuous : Ordering(this->continuousKeys()); - const Ordering discrete_ordering = - discrete ? *discrete : Ordering(this->discreteKeys()); - - // Eliminate continuous - HybridBayesNet::shared_ptr bayesNet; - HybridGaussianFactorGraph::shared_ptr discreteGraph; - std::tie(bayesNet, discreteGraph) = - BaseEliminateable::eliminatePartialSequential(continuous_ordering, - function, variableIndex); - - // Get the last continuous conditional which will have all the discrete keys - HybridConditional::shared_ptr last_conditional = - bayesNet->at(bayesNet->size() - 1); - DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - - // If no discrete variables, return the eliminated bayes net. - if (discrete_keys.size() == 0) { - return bayesNet; - } - - // DecisionTree for P'(X|M, Z) for all mode sequences M - const AlgebraicDecisionTree probPrimeTree = - this->continuousProbPrimes(discrete_keys, bayesNet); - - // Add the model selection factor P(M|Z) - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - - // Perform discrete elimination - HybridBayesNet::shared_ptr discreteBayesNet = - discreteGraph->BaseEliminateable::eliminateSequential( - discrete_ordering, function, variableIndex); - - bayesNet->add(*discreteBayesNet); - - return bayesNet; -} - -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateSequential( - OptionalOrderingType orderingType, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - return BaseEliminateable::eliminateSequential(orderingType, function, - variableIndex); -} - -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateSequential( - const Ordering &ordering, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - // Segregate the continuous and the discrete keys - Ordering continuous_ordering, discrete_ordering; - std::tie(continuous_ordering, discrete_ordering) = - this->separateContinuousDiscreteOrdering(ordering); - - return this->eliminateHybridSequential(continuous_ordering, discrete_ordering, - function, variableIndex); -} - /* ************************************************************************ */ boost::shared_ptr HybridGaussianFactorGraph::eliminateHybridMultifrontal( diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 47b94070f..a0d80b154 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -302,32 +302,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph std::pair separateContinuousDiscreteOrdering( const Ordering& ordering) const; - /** - * @brief Custom elimination function which computes the correct - * continuous probabilities. Returns a bayes net. - * - * @param continuous Optional ordering for all continuous variables. - * @param discrete Optional ordering for all discrete variables. - * @return boost::shared_ptr - */ - boost::shared_ptr eliminateHybridSequential( - const boost::optional continuous = boost::none, - const boost::optional discrete = boost::none, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - - /// Sequential elimination overload for hybrid - boost::shared_ptr eliminateSequential( - OptionalOrderingType orderingType = boost::none, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - - /// Sequential elimination overload for hybrid - boost::shared_ptr eliminateSequential( - const Ordering& ordering, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - /** * @brief Custom elimination function which computes the correct * continuous probabilities. Returns a bayes tree. diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 12f6949ab..7400053bf 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -32,7 +32,7 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph, addConditionals(graph, hybridBayesNet_, ordering); // Eliminate. - auto bayesNetFragment = graph.eliminateHybridSequential(); + auto bayesNetFragment = graph.eliminateSequential(); /// Prune if (maxNrLeaves) {