diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h index 9186e04a8..9b6854026 100644 --- a/gtsam/hybrid/HybridEliminationTree.h +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -73,6 +73,7 @@ class GTSAM_EXPORT HybridEliminationTree /** Test whether the tree is equal to another */ bool equals(const This& other, double tol = 1e-9) const; + protected: /** * @brief Helper method to eliminate continuous variables. * @@ -87,6 +88,22 @@ class GTSAM_EXPORT HybridEliminationTree boost::shared_ptr> eliminateContinuous(Eliminate function) const; + /** + * @brief Compute the unnormalized probability P'(X | M, Z) + * for the factor graph in each leaf of the discrete tree. + * The discrete decision tree formed as a result is added to the + * `discreteGraph` for discrete elimination. + * + * @param continuousBayesNet The bayes nets corresponding to + * the eliminated continuous variables. + * @param discreteGraph Factor graph consisting of factors + * on discrete variables only. + * @return boost::shared_ptr + */ + boost::shared_ptr addProbPrimes( + const HybridBayesNet::shared_ptr& continuousBayesNet, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; + /** * @brief Helper method to eliminate the discrete variables after the * continuous variables have been eliminated. @@ -105,6 +122,7 @@ class GTSAM_EXPORT HybridEliminationTree Eliminate function, const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; + public: /** * @brief Override the EliminationTree eliminate method * so we can perform hybrid elimination correctly. diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index f233d4bef..43b043e30 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -252,11 +252,11 @@ HybridJunctionTree::eliminateDiscrete( /* ************************************************************************* */ boost::shared_ptr HybridJunctionTree::addProbPrimes( - const HybridGaussianFactorGraph& graph, const HybridBayesTree::shared_ptr& continuousBayesTree, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph, - const Ordering& continuous_ordering, - const Ordering& discrete_ordering) const { + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const { + Ordering continuous_ordering = etree_.continuousOrdering(); + Ordering discrete_ordering = etree_.discreteOrdering(); + // If we have eliminated continuous variables // and have discrete variables to eliminate, // then compute P(X | M, Z) @@ -272,6 +272,8 @@ boost::shared_ptr HybridJunctionTree::addProbPrimes( std::set dkeys_set(discrete_keys.begin(), discrete_keys.end()); discrete_keys = DiscreteKeys(dkeys_set.begin(), dkeys_set.end()); + FactorGraphType graph = etree_.graph(); + // DecisionTree for P'(X|M, Z) for all mode sequences M const AlgebraicDecisionTree probPrimeTree = graph.continuousProbPrimes(discrete_keys, continuousBayesTree); @@ -298,8 +300,7 @@ HybridJunctionTree::eliminate(const Eliminate& function) const { this->eliminateContinuous(function, graph, continuous_ordering); FactorGraphType::shared_ptr updatedDiscreteGraph = - this->addProbPrimes(graph, continuousBayesTree, discreteGraph, - continuous_ordering, discrete_ordering); + this->addProbPrimes(continuousBayesTree, discreteGraph); // Eliminate discrete variables to get the discrete bayes tree. return this->eliminateDiscrete(function, continuousBayesTree, diff --git a/gtsam/hybrid/HybridJunctionTree.h b/gtsam/hybrid/HybridJunctionTree.h index 2dc13d5e3..d0473c33d 100644 --- a/gtsam/hybrid/HybridJunctionTree.h +++ b/gtsam/hybrid/HybridJunctionTree.h @@ -73,14 +73,15 @@ class GTSAM_EXPORT HybridJunctionTree */ HybridJunctionTree(const HybridEliminationTree& eliminationTree); + protected: /** - * @brief - * - * @param function - * @param graph - * @param continuous_ordering + * @brief Eliminate all the continuous variables from the factor graph. + * + * @param function The hybrid elimination function. + * @param graph The factor graph to eliminate. + * @param continuous_ordering The ordering of continuous variables. * @return std::pair, - * boost::shared_ptr> + * boost::shared_ptr> */ std::pair, boost::shared_ptr> @@ -89,14 +90,17 @@ class GTSAM_EXPORT HybridJunctionTree const Ordering& continuous_ordering) const; /** - * @brief - * - * @param function - * @param continuousBayesTree - * @param discreteGraph - * @param discrete_ordering + * @brief Eliminate all the discrete variables in the hybrid factor graph. + * + * @param function The hybrid elimination function. + * @param continuousBayesTree The bayes tree corresponding to + * the eliminated continuous variables. + * @param discreteGraph Factor graph of factors containing + * only discrete variables. + * @param discrete_ordering The elimination ordering for + * the discrete variables. * @return std::pair, - * boost::shared_ptr> + * boost::shared_ptr> */ std::pair, boost::shared_ptr> @@ -106,28 +110,29 @@ class GTSAM_EXPORT HybridJunctionTree const Ordering& discrete_ordering) const; /** - * @brief - * - * @param graph - * @param continuousBayesTree - * @param discreteGraph - * @param continuous_ordering - * @param discrete_ordering - * @return boost::shared_ptr + * @brief Compute the unnormalized probability P'(X | M, Z) + * for the factor graph in each leaf of the discrete tree. + * The discrete decision tree formed as a result is added to the + * `discreteGraph` for discrete elimination. + * + * @param continuousBayesTree The bayes tree corresponding to + * the eliminated continuous variables. + * @param discreteGraph Factor graph consisting of factors + * on discrete variables only. + * @return boost::shared_ptr */ boost::shared_ptr addProbPrimes( - const HybridGaussianFactorGraph& graph, const HybridBayesTree::shared_ptr& continuousBayesTree, - const HybridGaussianFactorGraph::shared_ptr& discreteGraph, - const Ordering& continuous_ordering, - const Ordering& discrete_ordering) const; + const HybridGaussianFactorGraph::shared_ptr& discreteGraph) const; + public: /** - * @brief - * - * @param function + * @brief Override the eliminate method so we can + * perform hybrid elimination correctly. + * + * @param function The hybrid elimination function. * @return std::pair, - * boost::shared_ptr> + * boost::shared_ptr> */ std::pair, boost::shared_ptr>