diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 1afe4f38a..c430fac2c 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -550,98 +550,5 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering( return std::make_pair(continuous_ordering, discrete_ordering); } -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateHybridMultifrontal( - const boost::optional continuous, - const boost::optional discrete, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - const Ordering continuous_ordering = - continuous ? *continuous : Ordering(this->continuousKeys()); - const Ordering discrete_ordering = - discrete ? *discrete : Ordering(this->discreteKeys()); - - // Eliminate continuous - HybridBayesTree::shared_ptr bayesTree; - HybridGaussianFactorGraph::shared_ptr discreteGraph; - std::tie(bayesTree, discreteGraph) = - BaseEliminateable::eliminatePartialMultifrontal(continuous_ordering, - function, variableIndex); - - // Get the last continuous conditional which will have all the discrete - const Key last_continuous_key = continuous_ordering.back(); - HybridConditional::shared_ptr last_conditional = - (*bayesTree)[last_continuous_key]->conditional(); - DiscreteKeys discrete_keys = last_conditional->discreteKeys(); - - // If not discrete variables, return the eliminated bayes net. - if (discrete_keys.size() == 0) { - return bayesTree; - } - - // DecisionTree for P'(X|M, Z) for all mode sequences M - const AlgebraicDecisionTree probPrimeTree = - this->continuousProbPrimes(discrete_keys, bayesTree); - - // Add the model selection factor P(M|Z) - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); - - // Eliminate discrete variables to get the discrete bayes tree. - // This bayes tree will be updated with the - // continuous variables as the child nodes. - HybridBayesTree::shared_ptr updatedBayesTree = - discreteGraph->BaseEliminateable::eliminateMultifrontal(discrete_ordering, - function); - - // Get the clique with all the discrete keys. - // There should only be 1 clique. - const HybridBayesTree::sharedClique discrete_clique = - (*updatedBayesTree)[discrete_ordering.at(0)]; - - std::set clique_set; - for (auto node : bayesTree->nodes()) { - clique_set.insert(node.second); - } - - // Set the root of the bayes tree as the discrete clique - for (auto clique : clique_set) { - if (clique->conditional()->parents() == - discrete_clique->conditional()->frontals()) { - updatedBayesTree->addClique(clique, discrete_clique); - - } else { - // Remove the clique from the children of the parents since it will get - // added again in addClique. - auto clique_it = std::find(clique->parent()->children.begin(), - clique->parent()->children.end(), clique); - clique->parent()->children.erase(clique_it); - updatedBayesTree->addClique(clique, clique->parent()); - } - } - return updatedBayesTree; -} - -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateMultifrontal( - OptionalOrderingType orderingType, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - return BaseEliminateable::eliminateMultifrontal(orderingType, function, - variableIndex); -} - -/* ************************************************************************ */ -boost::shared_ptr -HybridGaussianFactorGraph::eliminateMultifrontal( - const Ordering &ordering, const Eliminate &function, - OptionalVariableIndex variableIndex) const { - // Segregate the continuous and the discrete keys - Ordering continuous_ordering, discrete_ordering; - std::tie(continuous_ordering, discrete_ordering) = - this->separateContinuousDiscreteOrdering(ordering); - - return this->eliminateHybridMultifrontal( - continuous_ordering, discrete_ordering, function, variableIndex); -} } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index a0d80b154..210ce50e9 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -302,31 +302,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph std::pair separateContinuousDiscreteOrdering( const Ordering& ordering) const; - /** - * @brief Custom elimination function which computes the correct - * continuous probabilities. Returns a bayes tree. - * - * @param continuous Optional ordering for all continuous variables. - * @param discrete Optional ordering for all discrete variables. - * @return boost::shared_ptr - */ - boost::shared_ptr eliminateHybridMultifrontal( - const boost::optional continuous = boost::none, - const boost::optional discrete = boost::none, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - - /// Multifrontal elimination overload for hybrid - boost::shared_ptr eliminateMultifrontal( - OptionalOrderingType orderingType = boost::none, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; - - /// Multifrontal elimination overload for hybrid - boost::shared_ptr eliminateMultifrontal( - const Ordering& ordering, - const Eliminate& function = EliminationTraitsType::DefaultEliminate, - OptionalVariableIndex variableIndex = boost::none) const; + /** * @brief Return a Colamd constrained ordering where the discrete keys are diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 422c200a4..f233d4bef 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -142,7 +142,8 @@ struct HybridConstructorTraversalData { /* ************************************************************************* */ HybridJunctionTree::HybridJunctionTree( - const HybridEliminationTree& eliminationTree) { + const HybridEliminationTree& eliminationTree) + : etree_(eliminationTree) { gttic(JunctionTree_FromEliminationTree); // Here we rely on the BayesNet having been produced by this elimination tree, // such that the conditionals are arranged in DFS post-order. We traverse the @@ -169,4 +170,140 @@ HybridJunctionTree::HybridJunctionTree( Base::remainingFactors_ = eliminationTree.remainingFactors(); } +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridJunctionTree::eliminateContinuous( + const Eliminate& function, const HybridGaussianFactorGraph& graph, + const Ordering& continuous_ordering) const { + HybridBayesTree::shared_ptr continuousBayesTree; + HybridGaussianFactorGraph::shared_ptr discreteGraph; + + if (continuous_ordering.size() > 0) { + HybridEliminationTree continuous_etree(graph, etree_.variableIndex(), + continuous_ordering); + + This continuous_junction_tree(continuous_etree); + std::tie(continuousBayesTree, discreteGraph) = + continuous_junction_tree.Base::eliminate(function); + + } else { + continuousBayesTree = boost::make_shared(); + discreteGraph = boost::make_shared(graph); + } + + return std::make_pair(continuousBayesTree, discreteGraph); +} +/* ************************************************************************* */ +std::pair, + boost::shared_ptr> +HybridJunctionTree::eliminateDiscrete( + const Eliminate& function, + const HybridBayesTree::shared_ptr& continuousBayesTree, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph, + const Ordering& discrete_ordering) const { + HybridBayesTree::shared_ptr updatedBayesTree; + HybridGaussianFactorGraph::shared_ptr finalGraph; + if (discrete_ordering.size() > 0) { + HybridEliminationTree discrete_etree( + *discreteGraph, VariableIndex(*discreteGraph), discrete_ordering); + + This discrete_junction_tree(discrete_etree); + + std::tie(updatedBayesTree, finalGraph) = + discrete_junction_tree.Base::eliminate(function); + + // Get the clique with all the discrete keys. + // There should only be 1 clique. + const HybridBayesTree::sharedClique discrete_clique = + (*updatedBayesTree)[discrete_ordering.at(0)]; + + std::set clique_set; + for (auto node : continuousBayesTree->nodes()) { + clique_set.insert(node.second); + } + + // Set the root of the bayes tree as the discrete clique + for (auto clique : clique_set) { + if (clique->conditional()->parents() == + discrete_clique->conditional()->frontals()) { + updatedBayesTree->addClique(clique, discrete_clique); + + } else { + if (clique->parent()) { + // Remove the clique from the children of the parents since it will + // get added again in addClique. + auto clique_it = std::find(clique->parent()->children.begin(), + clique->parent()->children.end(), clique); + clique->parent()->children.erase(clique_it); + updatedBayesTree->addClique(clique, clique->parent()); + } else { + updatedBayesTree->addClique(clique); + } + } + } + } else { + updatedBayesTree = continuousBayesTree; + finalGraph = discreteGraph; + } + + return std::make_pair(updatedBayesTree, finalGraph); +} + +/* ************************************************************************* */ +boost::shared_ptr HybridJunctionTree::addProbPrimes( + const HybridGaussianFactorGraph& graph, + const HybridBayesTree::shared_ptr& continuousBayesTree, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph, + const Ordering& continuous_ordering, + const Ordering& discrete_ordering) const { + // If we have eliminated continuous variables + // and have discrete variables to eliminate, + // then compute P(X | M, Z) + if (continuous_ordering.size() > 0 && discrete_ordering.size() > 0) { + // Collect all the discrete keys + DiscreteKeys discrete_keys; + for (auto node : continuousBayesTree->nodes()) { + auto node_dkeys = node.second->conditional()->discreteKeys(); + discrete_keys.insert(discrete_keys.end(), node_dkeys.begin(), + node_dkeys.end()); + } + // Remove duplicates and convert back to DiscreteKeys + std::set dkeys_set(discrete_keys.begin(), discrete_keys.end()); + discrete_keys = DiscreteKeys(dkeys_set.begin(), dkeys_set.end()); + + // DecisionTree for P'(X|M, Z) for all mode sequences M + const AlgebraicDecisionTree probPrimeTree = + graph.continuousProbPrimes(discrete_keys, continuousBayesTree); + + // Add the model selection factor P(M|Z) + discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + } + + return discreteGraph; +} + +/* ************************************************************************* */ +std::pair +HybridJunctionTree::eliminate(const Eliminate& function) const { + Ordering continuous_ordering = etree_.continuousOrdering(); + Ordering discrete_ordering = etree_.discreteOrdering(); + + FactorGraphType graph = etree_.graph(); + + // Eliminate continuous + BayesTreeType::shared_ptr continuousBayesTree; + FactorGraphType::shared_ptr discreteGraph; + std::tie(continuousBayesTree, discreteGraph) = + this->eliminateContinuous(function, graph, continuous_ordering); + + FactorGraphType::shared_ptr updatedDiscreteGraph = + this->addProbPrimes(graph, continuousBayesTree, discreteGraph, + continuous_ordering, discrete_ordering); + + // Eliminate discrete variables to get the discrete bayes tree. + return this->eliminateDiscrete(function, continuousBayesTree, + updatedDiscreteGraph, discrete_ordering); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridJunctionTree.h b/gtsam/hybrid/HybridJunctionTree.h index 4b0c369a8..2dc13d5e3 100644 --- a/gtsam/hybrid/HybridJunctionTree.h +++ b/gtsam/hybrid/HybridJunctionTree.h @@ -51,10 +51,15 @@ class HybridEliminationTree; */ class GTSAM_EXPORT HybridJunctionTree : public JunctionTree { + /// Record the elimination tree for use in hybrid elimination. + HybridEliminationTree etree_; + /// Store the provided variable index. + VariableIndex variable_index_; + public: typedef JunctionTree Base; ///< Base class - typedef HybridJunctionTree This; ///< This class + typedef HybridJunctionTree This; ///< This class typedef boost::shared_ptr shared_ptr; ///< Shared pointer to this class /** @@ -67,6 +72,66 @@ class GTSAM_EXPORT HybridJunctionTree * @return The elimination tree */ HybridJunctionTree(const HybridEliminationTree& eliminationTree); + + /** + * @brief + * + * @param function + * @param graph + * @param continuous_ordering + * @return std::pair, + * boost::shared_ptr> + */ + std::pair, + boost::shared_ptr> + eliminateContinuous(const Eliminate& function, + const HybridGaussianFactorGraph& graph, + const Ordering& continuous_ordering) const; + + /** + * @brief + * + * @param function + * @param continuousBayesTree + * @param discreteGraph + * @param discrete_ordering + * @return std::pair, + * boost::shared_ptr> + */ + std::pair, + boost::shared_ptr> + eliminateDiscrete(const Eliminate& function, + const HybridBayesTree::shared_ptr& continuousBayesTree, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph, + const Ordering& discrete_ordering) const; + + /** + * @brief + * + * @param graph + * @param continuousBayesTree + * @param discreteGraph + * @param continuous_ordering + * @param discrete_ordering + * @return boost::shared_ptr + */ + boost::shared_ptr addProbPrimes( + const HybridGaussianFactorGraph& graph, + const HybridBayesTree::shared_ptr& continuousBayesTree, + const HybridGaussianFactorGraph::shared_ptr& discreteGraph, + const Ordering& continuous_ordering, + const Ordering& discrete_ordering) const; + + /** + * @brief + * + * @param function + * @return std::pair, + * boost::shared_ptr> + */ + std::pair, + boost::shared_ptr> + eliminate(const Eliminate& function) const; }; } // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 248d71d5f..6288bcd93 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -182,8 +182,9 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) { boost::make_shared(X(1), I_3x3, Vector3::Ones())})); hfg.add(DecisionTreeFactor(m1, {2, 8})); - //TODO(Varun) Adding extra discrete variable not connected to continuous variable throws segfault - // hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); + // TODO(Varun) Adding extra discrete variable not connected to continuous + // variable throws segfault + // hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4")); HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal(hfg.getHybridOrdering()); @@ -276,7 +277,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) { std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal(ordering_full); // 9 cliques in the bayes tree and 0 remaining variables to eliminate. - EXPECT_LONGS_EQUAL(9, hbt->size()); + EXPECT_LONGS_EQUAL(7, hbt->size()); EXPECT_LONGS_EQUAL(0, remaining->size()); /*