Generic BayesTree from BayesNet constructor, works for GaussianBayesTree(GaussianBayesNet)

release/4.3a0
Richard Roberts 2012-03-12 22:24:28 +00:00
parent c842c5c9fd
commit 0531983c74
5 changed files with 125 additions and 7 deletions

View File

@ -234,15 +234,56 @@ namespace gtsam {
} }
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<class CONDITIONAL, class CLIQUE> template<class CONDITIONAL, class CLIQUE>
BayesTree<CONDITIONAL,CLIQUE>::BayesTree() { void BayesTree<CONDITIONAL,CLIQUE>::recursiveTreeBuild(const boost::shared_ptr<BayesTreeClique<IndexConditional> >& symbolic,
const std::vector<boost::shared_ptr<CONDITIONAL> >& conditionals,
const typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique& parent) {
// Helper function to build a non-symbolic tree (e.g. Gaussian) using a
// symbolic tree, used in the BT(BN) constructor.
// Build the current clique
FastList<typename CONDITIONAL::shared_ptr> cliqueConditionals;
BOOST_FOREACH(Index j, symbolic->conditional()->frontals()) {
cliqueConditionals.push_back(conditionals[j]); }
typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique thisClique(new CLIQUE(CONDITIONAL::Combine(cliqueConditionals.begin(), cliqueConditionals.end())));
// Add the new clique with the current parent
this->addClique(thisClique, parent);
// Build the children, whose parent is the new clique
BOOST_FOREACH(const BayesTree<IndexConditional>::sharedClique& child, symbolic->children()) {
this->recursiveTreeBuild(child, conditionals, thisClique); }
}
/* ************************************************************************* */
template<class CONDITIONAL, class CLIQUE>
BayesTree<CONDITIONAL,CLIQUE>::BayesTree(const BayesNet<CONDITIONAL>& bayesNet) {
// First generate symbolic BT to determine clique structure
BayesTree<IndexConditional> sbt(bayesNet);
// Build index of variables to conditionals
std::vector<boost::shared_ptr<CONDITIONAL> > conditionals(sbt.root()->conditional()->frontals().back() + 1);
BOOST_FOREACH(const boost::shared_ptr<CONDITIONAL>& c, bayesNet) {
if(c->nrFrontals() != 1)
throw std::invalid_argument("BayesTree constructor from BayesNet only supports single frontal variable conditionals");
if(c->firstFrontalKey() >= conditionals.size())
throw std::invalid_argument("An inconsistent BayesNet was passed into the BayesTree constructor!");
if(conditionals[c->firstFrontalKey()])
throw std::invalid_argument("An inconsistent BayesNet with duplicate frontal variables was passed into the BayesTree constructor!");
conditionals[c->firstFrontalKey()] = c;
}
// Build the new tree
this->recursiveTreeBuild(sbt.root(), conditionals, sharedClique());
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<class CONDITIONAL, class CLIQUE> template<>
BayesTree<CONDITIONAL,CLIQUE>::BayesTree(const BayesNet<CONDITIONAL>& bayesNet) { inline BayesTree<IndexConditional>::BayesTree(const BayesNet<IndexConditional>& bayesNet) {
typename BayesNet<CONDITIONAL>::const_reverse_iterator rit; typename BayesNet<IndexConditional>::const_reverse_iterator rit;
for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit ) for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit )
insert(*this, *rit); insert(*this, *rit);
} }

View File

@ -31,6 +31,7 @@
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/BayesTreeCliqueBase.h> #include <gtsam/inference/BayesTreeCliqueBase.h>
#include <gtsam/inference/IndexConditional.h>
#include <gtsam/linear/VectorValues.h> #include <gtsam/linear/VectorValues.h>
namespace gtsam { namespace gtsam {
@ -127,15 +128,22 @@ namespace gtsam {
/** Fill the nodes index for a subtree */ /** Fill the nodes index for a subtree */
void fillNodesIndex(const sharedClique& subtree); void fillNodesIndex(const sharedClique& subtree);
/** Helper function to build a non-symbolic tree (e.g. Gaussian) using a
* symbolic tree, used in the BT(BN) constructor.
*/
void recursiveTreeBuild(const boost::shared_ptr<BayesTreeClique<IndexConditional> >& symbolic,
const std::vector<boost::shared_ptr<CONDITIONAL> >& conditionals,
const typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique& parent);
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Create an empty Bayes Tree */ /** Create an empty Bayes Tree */
BayesTree(); BayesTree() {}
/** Create a Bayes Tree from a Bayes Net */ /** Create a Bayes Tree from a Bayes Net (requires CONDITIONAL is IndexConditional *or* CONDITIONAL::Combine) */
BayesTree(const BayesNet<CONDITIONAL>& bayesNet); BayesTree(const BayesNet<CONDITIONAL>& bayesNet);
/// @} /// @}

View File

@ -19,6 +19,7 @@
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/weak_ptr.hpp>
#include <gtsam/base/types.h> #include <gtsam/base/types.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>

View File

@ -137,6 +137,15 @@ public:
/** Copy constructor */ /** Copy constructor */
GaussianConditional(const GaussianConditional& rhs); GaussianConditional(const GaussianConditional& rhs);
/** Combine several GaussianConditional into a single dense GC. The
* conditionals enumerated by \c first and \c last must be in increasing
* order, meaning that the parents of any conditional may not include a
* conditional coming before it.
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<GaussianConditional>.
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<GaussianConditional>. */
template<typename ITERATOR>
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
/** Assignment operator */ /** Assignment operator */
GaussianConditional& operator=(const GaussianConditional& rhs); GaussianConditional& operator=(const GaussianConditional& rhs);
@ -274,5 +283,49 @@ GaussianConditional::GaussianConditional(ITERATOR firstKey, ITERATOR lastKey,
} }
/* ************************************************************************* */ /* ************************************************************************* */
template<typename ITERATOR>
GaussianConditional::shared_ptr GaussianConditional::Combine(ITERATOR firstConditional, ITERATOR lastConditional) {
// TODO: check for being a clique
// Get dimensions from first conditional
std::vector<size_t> dims; dims.reserve((*firstConditional)->size() + 1);
for(const_iterator j = (*firstConditional)->begin(); j != (*firstConditional)->end(); ++j)
dims.push_back((*firstConditional)->dim(j));
dims.push_back(1);
// We assume the conditionals form clique, so the first n variables will be
// frontal variables in the new conditional.
size_t nFrontals = 0;
size_t nRows = 0;
for(ITERATOR c = firstConditional; c != lastConditional; ++c) {
nRows += dims[nFrontals];
++ nFrontals;
}
// Allocate combined conditional, has same keys as firstConditional
Matrix tempCombined;
VerticalBlockView<Matrix> tempBlockView(tempCombined, dims.begin(), dims.end(), 0);
GaussianConditional::shared_ptr combinedConditional(new GaussianConditional((*firstConditional)->begin(), (*firstConditional)->end(), nFrontals, tempBlockView, zero(nRows)));
// Resize to correct number of rows
combinedConditional->matrix_.resize(nRows, combinedConditional->matrix_.cols());
combinedConditional->rsd_.rowEnd() = combinedConditional->matrix_.rows();
// Copy matrix and sigmas
const size_t totalDims = combinedConditional->matrix_.cols();
size_t currentSlot = 0;
for(ITERATOR c = firstConditional; c != lastConditional; ++c) {
const size_t startRow = combinedConditional->rsd_.offset(currentSlot); // Start row is same as start column
combinedConditional->rsd_.range(0, currentSlot).block(startRow, 0, dims[currentSlot], combinedConditional->rsd_.offset(currentSlot)).operator=(
Matrix::Zero(dims[currentSlot], combinedConditional->rsd_.offset(currentSlot)));
combinedConditional->rsd_.range(currentSlot, dims.size()).block(startRow, 0, dims[currentSlot], totalDims - startRow).operator=(
(*c)->matrix_);
combinedConditional->sigmas_.segment(startRow, dims[currentSlot]) = (*c)->sigmas_;
++ currentSlot;
}
return combinedConditional;
}
} // gtsam } // gtsam

View File

@ -103,6 +103,21 @@ TEST( GaussianJunctionTree, eliminate )
EXPECT(assert_equal(*(bayesTree_expected.root()->children().front()), *(rootClique->children().front()))); EXPECT(assert_equal(*(bayesTree_expected.root()->children().front()), *(rootClique->children().front())));
} }
/* ************************************************************************* */
TEST_UNSAFE( GaussianJunctionTree, GBNConstructor )
{
GaussianFactorGraph fg = createChain();
GaussianJunctionTree jt(fg);
BayesTree<GaussianConditional>::sharedClique root = jt.eliminate(&EliminateQR);
BayesTree<GaussianConditional> expected;
expected.insert(root);
GaussianBayesNet bn(*GaussianSequentialSolver(fg).eliminate());
BayesTree<GaussianConditional> actual(bn);
EXPECT(assert_equal(expected, actual));
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST( GaussianJunctionTree, optimizeMultiFrontal ) TEST( GaussianJunctionTree, optimizeMultiFrontal )
{ {