diff --git a/gtsam/hybrid/HybridEliminationTree.cpp b/gtsam/hybrid/HybridEliminationTree.cpp index fe9190571..e54105919 100644 --- a/gtsam/hybrid/HybridEliminationTree.cpp +++ b/gtsam/hybrid/HybridEliminationTree.cpp @@ -47,4 +47,92 @@ bool HybridEliminationTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridEliminationTree::eliminateContinuous(Eliminate function) const { + if (continuous_ordering_.size() > 0) { + This continuous_etree(graph_, variable_index_, continuous_ordering_); + return continuous_etree.Base::eliminate(function); + + } else { + HybridBayesNet::shared_ptr bayesNet = boost::make_shared(); + HybridGaussianFactorGraph::shared_ptr discreteGraph = + boost::make_shared(graph_); + return std::make_pair(bayesNet, discreteGraph); + } +} + +/* ************************************************************************* */ +boost::shared_ptr +HybridEliminationTree::addProbPrimes( + const HybridBayesNet::shared_ptr& continuousBayesNet, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const { + if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) { + // Get the last continuous conditional + // which will have all the discrete keys + HybridConditional::shared_ptr last_conditional = + continuousBayesNet->at(continuousBayesNet->size() - 1); + DiscreteKeys discrete_keys = last_conditional->discreteKeys(); + + // DecisionTree for P'(X|M, Z) for all mode sequences M + const AlgebraicDecisionTree probPrimeTree = + graph_.continuousProbPrimes(discrete_keys, continuousBayesNet); + + // Add the model selection factor P(M|Z) + discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + } + return discreteGraph; +} + +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridEliminationTree::eliminateDiscrete( + Eliminate function, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const { + HybridBayesNet::shared_ptr discreteBayesNet; + HybridGaussianFactorGraph::shared_ptr finalGraph; + if (discrete_ordering_.size() > 0) { + This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph), + discrete_ordering_); + + std::tie(discreteBayesNet, finalGraph) = + discrete_etree.Base::eliminate(function); + + } else { + discreteBayesNet = boost::make_shared(); + finalGraph = discreteGraph; + } + + return std::make_pair(discreteBayesNet, finalGraph); +} + +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridEliminationTree::eliminate(Eliminate function) const { + // Perform continuous elimination + HybridBayesNet::shared_ptr bayesNet; + HybridGaussianFactorGraph::shared_ptr discreteGraph; + std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function); + + // If we have eliminated continuous variables + // and have discrete variables to eliminate, + // then compute P(X | M, Z) + HybridGaussianFactorGraph::shared_ptr updatedDiscreteGraph = + addProbPrimes(bayesNet, discreteGraph); + + // Perform discrete elimination + HybridBayesNet::shared_ptr discreteBayesNet; + HybridGaussianFactorGraph::shared_ptr finalGraph; + std::tie(discreteBayesNet, finalGraph) = + eliminateDiscrete(function, updatedDiscreteGraph); + + // Add the discrete conditionals to the hybrid conditionals + bayesNet->add(*discreteBayesNet); + + return std::make_pair(bayesNet, finalGraph); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h index 65d614ca3..9186e04a8 100644 --- a/gtsam/hybrid/HybridEliminationTree.h +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -80,22 +80,12 @@ class GTSAM_EXPORT HybridEliminationTree * and the original graph. * * @param function Elimination function for hybrid elimination. - * @return std::pair, - * boost::shared_ptr > + * @return std::pair, + * boost::shared_ptr > */ - std::pair, boost::shared_ptr> - eliminateContinuous(Eliminate function) const { - if (continuous_ordering_.size() > 0) { - This continuous_etree(graph_, variable_index_, continuous_ordering_); - return continuous_etree.Base::eliminate(function); - - } else { - BayesNetType::shared_ptr bayesNet = boost::make_shared(); - FactorGraphType::shared_ptr discreteGraph = - boost::make_shared(graph_); - return std::make_pair(bayesNet, discreteGraph); - } - } + std::pair, + boost::shared_ptr> + eliminateContinuous(Eliminate function) const; /** * @brief Helper method to eliminate the discrete variables after the @@ -104,75 +94,28 @@ class GTSAM_EXPORT HybridEliminationTree * If there are no discrete variables, return an empty bayes net and the * discreteGraph which is passed in. * - * @param function Elimination function + * @param function Hybrid elimination function * @param discreteGraph The factor graph with the factor ϕ(X | M, Z). - * @return std::pair, - * boost::shared_ptr > + * @return std::pair, + * boost::shared_ptr > */ - std::pair, boost::shared_ptr> - eliminateDiscrete(Eliminate function, - const FactorGraphType::shared_ptr& discreteGraph) const { - BayesNetType::shared_ptr discreteBayesNet; - FactorGraphType::shared_ptr finalGraph; - if (discrete_ordering_.size() > 0) { - This discrete_etree(*discreteGraph, VariableIndex(*discreteGraph), - discrete_ordering_); - - std::tie(discreteBayesNet, finalGraph) = - discrete_etree.Base::eliminate(function); - - } else { - discreteBayesNet = boost::make_shared(); - finalGraph = discreteGraph; - } - - return std::make_pair(discreteBayesNet, finalGraph); - } + std::pair, + boost::shared_ptr> + eliminateDiscrete( + Eliminate function, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; /** * @brief Override the EliminationTree eliminate method * so we can perform hybrid elimination correctly. * - * @param function - * @return std::pair, - * boost::shared_ptr > + * @param function Hybrid elimination function + * @return std::pair, + * boost::shared_ptr > */ - std::pair, boost::shared_ptr> - eliminate(Eliminate function) const { - // Perform continuous elimination - BayesNetType::shared_ptr bayesNet; - FactorGraphType::shared_ptr discreteGraph; - std::tie(bayesNet, discreteGraph) = this->eliminateContinuous(function); - - // If we have eliminated continuous variables - // and have discrete variables to eliminate, - // then compute P(X | M, Z) - if (continuous_ordering_.size() > 0 && discrete_ordering_.size() > 0) { - // Get the last continuous conditional - // which will have all the discrete keys - HybridConditional::shared_ptr last_conditional = - bayesNet->at(bayesNet->size() - 1); - DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - - // DecisionTree for P'(X|M, Z) for all mode sequences M - const AlgebraicDecisionTree probPrimeTree = - graph_.continuousProbPrimes(discrete_keys, bayesNet); - - // Add the model selection factor P(M|Z) - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - } - - // Perform discrete elimination - BayesNetType::shared_ptr discreteBayesNet; - FactorGraphType::shared_ptr finalGraph; - std::tie(discreteBayesNet, finalGraph) = - eliminateDiscrete(function, discreteGraph); - - // Add the discrete conditionals to the hybrid conditionals - bayesNet->add(*discreteBayesNet); - - return std::make_pair(bayesNet, finalGraph); - } + std::pair, + boost::shared_ptr> + eliminate(Eliminate function) const; Ordering continuousOrdering() const { return continuous_ordering_; } Ordering discreteOrdering() const { return discrete_ordering_; }