diff --git a/gtsam/inference/BayesTree-inl.h b/gtsam/inference/BayesTree-inl.h index 23148cea8..06e4d5053 100644 --- a/gtsam/inference/BayesTree-inl.h +++ b/gtsam/inference/BayesTree-inl.h @@ -234,15 +234,56 @@ namespace gtsam { } } - /* ************************************************************************* */ + /* ************************************************************************* */ template - BayesTree::BayesTree() { + void BayesTree::recursiveTreeBuild(const boost::shared_ptr >& symbolic, + const std::vector >& conditionals, + const typename BayesTree::sharedClique& parent) { + + // Helper function to build a non-symbolic tree (e.g. Gaussian) using a + // symbolic tree, used in the BT(BN) constructor. + + // Build the current clique + FastList cliqueConditionals; + BOOST_FOREACH(Index j, symbolic->conditional()->frontals()) { + cliqueConditionals.push_back(conditionals[j]); } + typename BayesTree::sharedClique thisClique(new CLIQUE(CONDITIONAL::Combine(cliqueConditionals.begin(), cliqueConditionals.end()))); + + // Add the new clique with the current parent + this->addClique(thisClique, parent); + + // Build the children, whose parent is the new clique + BOOST_FOREACH(const BayesTree::sharedClique& child, symbolic->children()) { + this->recursiveTreeBuild(child, conditionals, thisClique); } + } + + /* ************************************************************************* */ + template + BayesTree::BayesTree(const BayesNet& bayesNet) { + // First generate symbolic BT to determine clique structure + BayesTree sbt(bayesNet); + + // Build index of variables to conditionals + std::vector > conditionals(sbt.root()->conditional()->frontals().back() + 1); + BOOST_FOREACH(const boost::shared_ptr& c, bayesNet) { + if(c->nrFrontals() != 1) + throw std::invalid_argument("BayesTree constructor from BayesNet only supports single frontal variable conditionals"); + if(c->firstFrontalKey() >= conditionals.size()) + throw std::invalid_argument("An inconsistent BayesNet was passed into the BayesTree constructor!"); + if(conditionals[c->firstFrontalKey()]) + throw std::invalid_argument("An inconsistent BayesNet with duplicate frontal variables was passed into the BayesTree constructor!"); + + conditionals[c->firstFrontalKey()] = c; + } + + // Build the new tree + this->recursiveTreeBuild(sbt.root(), conditionals, sharedClique()); } /* ************************************************************************* */ - template - BayesTree::BayesTree(const BayesNet& bayesNet) { - typename BayesNet::const_reverse_iterator rit; + template<> + inline BayesTree::BayesTree(const BayesNet& bayesNet) { + typename BayesNet::const_reverse_iterator rit; for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit ) insert(*this, *rit); } diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 87788e3d8..06a47ba8e 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -31,6 +31,7 @@ #include #include #include +#include #include namespace gtsam { @@ -127,15 +128,22 @@ namespace gtsam { /** Fill the nodes index for a subtree */ void fillNodesIndex(const sharedClique& subtree); + /** Helper function to build a non-symbolic tree (e.g. Gaussian) using a + * symbolic tree, used in the BT(BN) constructor. + */ + void recursiveTreeBuild(const boost::shared_ptr >& symbolic, + const std::vector >& conditionals, + const typename BayesTree::sharedClique& parent); + public: /// @name Standard Constructors /// @{ /** Create an empty Bayes Tree */ - BayesTree(); + BayesTree() {} - /** Create a Bayes Tree from a Bayes Net */ + /** Create a Bayes Tree from a Bayes Net (requires CONDITIONAL is IndexConditional *or* CONDITIONAL::Combine) */ BayesTree(const BayesNet& bayesNet); /// @} diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index 08427bae4..82cec6f76 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -19,6 +19,7 @@ #include #include +#include #include #include diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index f2a77c894..7eb69b15b 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -137,6 +137,15 @@ public: /** Copy constructor */ GaussianConditional(const GaussianConditional& rhs); + /** Combine several GaussianConditional into a single dense GC. The + * conditionals enumerated by \c first and \c last must be in increasing + * order, meaning that the parents of any conditional may not include a + * conditional coming before it. + * @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr. + * @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr. */ + template + static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional); + /** Assignment operator */ GaussianConditional& operator=(const GaussianConditional& rhs); @@ -274,5 +283,49 @@ GaussianConditional::GaussianConditional(ITERATOR firstKey, ITERATOR lastKey, } /* ************************************************************************* */ +template +GaussianConditional::shared_ptr GaussianConditional::Combine(ITERATOR firstConditional, ITERATOR lastConditional) { + + // TODO: check for being a clique + + // Get dimensions from first conditional + std::vector dims; dims.reserve((*firstConditional)->size() + 1); + for(const_iterator j = (*firstConditional)->begin(); j != (*firstConditional)->end(); ++j) + dims.push_back((*firstConditional)->dim(j)); + dims.push_back(1); + + // We assume the conditionals form clique, so the first n variables will be + // frontal variables in the new conditional. + size_t nFrontals = 0; + size_t nRows = 0; + for(ITERATOR c = firstConditional; c != lastConditional; ++c) { + nRows += dims[nFrontals]; + ++ nFrontals; + } + + // Allocate combined conditional, has same keys as firstConditional + Matrix tempCombined; + VerticalBlockView tempBlockView(tempCombined, dims.begin(), dims.end(), 0); + GaussianConditional::shared_ptr combinedConditional(new GaussianConditional((*firstConditional)->begin(), (*firstConditional)->end(), nFrontals, tempBlockView, zero(nRows))); + + // Resize to correct number of rows + combinedConditional->matrix_.resize(nRows, combinedConditional->matrix_.cols()); + combinedConditional->rsd_.rowEnd() = combinedConditional->matrix_.rows(); + + // Copy matrix and sigmas + const size_t totalDims = combinedConditional->matrix_.cols(); + size_t currentSlot = 0; + for(ITERATOR c = firstConditional; c != lastConditional; ++c) { + const size_t startRow = combinedConditional->rsd_.offset(currentSlot); // Start row is same as start column + combinedConditional->rsd_.range(0, currentSlot).block(startRow, 0, dims[currentSlot], combinedConditional->rsd_.offset(currentSlot)).operator=( + Matrix::Zero(dims[currentSlot], combinedConditional->rsd_.offset(currentSlot))); + combinedConditional->rsd_.range(currentSlot, dims.size()).block(startRow, 0, dims[currentSlot], totalDims - startRow).operator=( + (*c)->matrix_); + combinedConditional->sigmas_.segment(startRow, dims[currentSlot]) = (*c)->sigmas_; + ++ currentSlot; + } + + return combinedConditional; +} } // gtsam diff --git a/gtsam/linear/tests/testGaussianJunctionTree.cpp b/gtsam/linear/tests/testGaussianJunctionTree.cpp index 2462f5255..3a9bcb4ad 100644 --- a/gtsam/linear/tests/testGaussianJunctionTree.cpp +++ b/gtsam/linear/tests/testGaussianJunctionTree.cpp @@ -103,6 +103,21 @@ TEST( GaussianJunctionTree, eliminate ) EXPECT(assert_equal(*(bayesTree_expected.root()->children().front()), *(rootClique->children().front()))); } +/* ************************************************************************* */ +TEST_UNSAFE( GaussianJunctionTree, GBNConstructor ) +{ + GaussianFactorGraph fg = createChain(); + GaussianJunctionTree jt(fg); + BayesTree::sharedClique root = jt.eliminate(&EliminateQR); + BayesTree expected; + expected.insert(root); + + GaussianBayesNet bn(*GaussianSequentialSolver(fg).eliminate()); + BayesTree actual(bn); + + EXPECT(assert_equal(expected, actual)); +} + /* ************************************************************************* */ TEST( GaussianJunctionTree, optimizeMultiFrontal ) {