diff --git a/gtsam/inference/GenericMultifrontalSolver-inl.h b/gtsam/inference/GenericMultifrontalSolver-inl.h index bd916fbe2..a202a4700 100644 --- a/gtsam/inference/GenericMultifrontalSolver-inl.h +++ b/gtsam/inference/GenericMultifrontalSolver-inl.h @@ -23,64 +23,71 @@ #include #include -#include - using namespace std; namespace gtsam { -/* ************************************************************************* */ -template -GenericMultifrontalSolver::GenericMultifrontalSolver(const FactorGraph& factorGraph) : - structure_(new VariableIndex(factorGraph)), junctionTree_(new JUNCTIONTREE(factorGraph, *structure_)) {} + /* ************************************************************************* */ + template + GenericMultifrontalSolver::GenericMultifrontalSolver( + const FactorGraph& graph) : + structure_(new VariableIndex(graph)), junctionTree_( + new JT(graph, *structure_)) { + } -/* ************************************************************************* */ -template -GenericMultifrontalSolver::GenericMultifrontalSolver( - const typename FactorGraph::shared_ptr& factorGraph, const VariableIndex::shared_ptr& variableIndex) : - structure_(variableIndex), junctionTree_(new JUNCTIONTREE(*factorGraph, *structure_)) {} + /* ************************************************************************* */ + template + GenericMultifrontalSolver::GenericMultifrontalSolver( + const typename FactorGraph::shared_ptr& graph, + const VariableIndex::shared_ptr& variableIndex) : + structure_(variableIndex), junctionTree_(new JT(*graph, *structure_)) { + } -/* ************************************************************************* */ -template -void GenericMultifrontalSolver::replaceFactors(const typename FactorGraph::shared_ptr& factorGraph) { - junctionTree_.reset(new JUNCTIONTREE(*factorGraph, *structure_)); -} + /* ************************************************************************* */ + template + void GenericMultifrontalSolver::replaceFactors( + const typename FactorGraph::shared_ptr& graph) { + junctionTree_.reset(new JT(*graph, *structure_)); + } -/* ************************************************************************* */ -template -typename JUNCTIONTREE::BayesTree::shared_ptr GenericMultifrontalSolver< - FACTOR, JUNCTIONTREE>::eliminate( - typename FactorGraph::Eliminate function) const { - typename JUNCTIONTREE::BayesTree::shared_ptr bayesTree( - new typename JUNCTIONTREE::BayesTree); - bayesTree->insert(junctionTree_->eliminate(function)); - return bayesTree; -} + /* ************************************************************************* */ + template + typename JT::BayesTree::shared_ptr GenericMultifrontalSolver::eliminate( + typename FactorGraph::Eliminate function) const { -/* ************************************************************************* */ -template -typename FactorGraph::shared_ptr GenericMultifrontalSolver::jointFactorGraph(const std::vector& js, - Eliminate function) const { + // eliminate junction tree, returns pointer to root + typename JT::BayesTree::sharedClique root = junctionTree_->eliminate(function); - // We currently have code written only for computing the + // create an empty Bayes tree and insert root clique + typename JT::BayesTree::shared_ptr bayesTree(new typename JT::BayesTree); + bayesTree->insert(root); - if (js.size() != 2) throw domain_error( - "*MultifrontalSolver::joint(js) currently can only compute joint marginals\n" - "for exactly two variables. You can call marginal to compute the\n" - "marginal for one variable. *SequentialSolver::joint(js) can compute the\n" - "joint marginal over any number of variables, so use that if necessary.\n"); + // return the Bayes tree + return bayesTree; + } - return eliminate(function)->joint(js[0], js[1], function); -} + /* ************************************************************************* */ + template + typename FactorGraph::shared_ptr GenericMultifrontalSolver::jointFactorGraph( + const std::vector& js, Eliminate function) const { -/* ************************************************************************* */ -template -typename FACTOR::shared_ptr GenericMultifrontalSolver::marginalFactor( - Index j, Eliminate function) const { - return eliminate(function)->marginalFactor(j, function); -} + // We currently have code written only for computing the + + if (js.size() != 2) throw domain_error( + "*MultifrontalSolver::joint(js) currently can only compute joint marginals\n" + "for exactly two variables. You can call marginal to compute the\n" + "marginal for one variable. *SequentialSolver::joint(js) can compute the\n" + "joint marginal over any number of variables, so use that if necessary.\n"); + + return eliminate(function)->joint(js[0], js[1], function); + } + + /* ************************************************************************* */ + template + typename F::shared_ptr GenericMultifrontalSolver::marginalFactor( + Index j, Eliminate function) const { + return eliminate(function)->marginalFactor(j, function); + } } -