Generic BayesTree from BayesNet constructor, works for GaussianBayesTree(GaussianBayesNet)
parent
c842c5c9fd
commit
0531983c74
|
|
@ -236,13 +236,54 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
template<class CONDITIONAL, class CLIQUE>
|
||||
BayesTree<CONDITIONAL,CLIQUE>::BayesTree() {
|
||||
void BayesTree<CONDITIONAL,CLIQUE>::recursiveTreeBuild(const boost::shared_ptr<BayesTreeClique<IndexConditional> >& symbolic,
|
||||
const std::vector<boost::shared_ptr<CONDITIONAL> >& conditionals,
|
||||
const typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique& parent) {
|
||||
|
||||
// Helper function to build a non-symbolic tree (e.g. Gaussian) using a
|
||||
// symbolic tree, used in the BT(BN) constructor.
|
||||
|
||||
// Build the current clique
|
||||
FastList<typename CONDITIONAL::shared_ptr> cliqueConditionals;
|
||||
BOOST_FOREACH(Index j, symbolic->conditional()->frontals()) {
|
||||
cliqueConditionals.push_back(conditionals[j]); }
|
||||
typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique thisClique(new CLIQUE(CONDITIONAL::Combine(cliqueConditionals.begin(), cliqueConditionals.end())));
|
||||
|
||||
// Add the new clique with the current parent
|
||||
this->addClique(thisClique, parent);
|
||||
|
||||
// Build the children, whose parent is the new clique
|
||||
BOOST_FOREACH(const BayesTree<IndexConditional>::sharedClique& child, symbolic->children()) {
|
||||
this->recursiveTreeBuild(child, conditionals, thisClique); }
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class CONDITIONAL, class CLIQUE>
|
||||
BayesTree<CONDITIONAL,CLIQUE>::BayesTree(const BayesNet<CONDITIONAL>& bayesNet) {
|
||||
typename BayesNet<CONDITIONAL>::const_reverse_iterator rit;
|
||||
// First generate symbolic BT to determine clique structure
|
||||
BayesTree<IndexConditional> sbt(bayesNet);
|
||||
|
||||
// Build index of variables to conditionals
|
||||
std::vector<boost::shared_ptr<CONDITIONAL> > conditionals(sbt.root()->conditional()->frontals().back() + 1);
|
||||
BOOST_FOREACH(const boost::shared_ptr<CONDITIONAL>& c, bayesNet) {
|
||||
if(c->nrFrontals() != 1)
|
||||
throw std::invalid_argument("BayesTree constructor from BayesNet only supports single frontal variable conditionals");
|
||||
if(c->firstFrontalKey() >= conditionals.size())
|
||||
throw std::invalid_argument("An inconsistent BayesNet was passed into the BayesTree constructor!");
|
||||
if(conditionals[c->firstFrontalKey()])
|
||||
throw std::invalid_argument("An inconsistent BayesNet with duplicate frontal variables was passed into the BayesTree constructor!");
|
||||
|
||||
conditionals[c->firstFrontalKey()] = c;
|
||||
}
|
||||
|
||||
// Build the new tree
|
||||
this->recursiveTreeBuild(sbt.root(), conditionals, sharedClique());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<>
|
||||
inline BayesTree<IndexConditional>::BayesTree(const BayesNet<IndexConditional>& bayesNet) {
|
||||
typename BayesNet<IndexConditional>::const_reverse_iterator rit;
|
||||
for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit )
|
||||
insert(*this, *rit);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@
|
|||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||
#include <gtsam/inference/IndexConditional.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
|
@ -127,15 +128,22 @@ namespace gtsam {
|
|||
/** Fill the nodes index for a subtree */
|
||||
void fillNodesIndex(const sharedClique& subtree);
|
||||
|
||||
/** Helper function to build a non-symbolic tree (e.g. Gaussian) using a
|
||||
* symbolic tree, used in the BT(BN) constructor.
|
||||
*/
|
||||
void recursiveTreeBuild(const boost::shared_ptr<BayesTreeClique<IndexConditional> >& symbolic,
|
||||
const std::vector<boost::shared_ptr<CONDITIONAL> >& conditionals,
|
||||
const typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique& parent);
|
||||
|
||||
public:
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Create an empty Bayes Tree */
|
||||
BayesTree();
|
||||
BayesTree() {}
|
||||
|
||||
/** Create a Bayes Tree from a Bayes Net */
|
||||
/** Create a Bayes Tree from a Bayes Net (requires CONDITIONAL is IndexConditional *or* CONDITIONAL::Combine) */
|
||||
BayesTree(const BayesNet<CONDITIONAL>& bayesNet);
|
||||
|
||||
/// @}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <boost/weak_ptr.hpp>
|
||||
|
||||
#include <gtsam/base/types.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
|
|
|||
|
|
@ -137,6 +137,15 @@ public:
|
|||
/** Copy constructor */
|
||||
GaussianConditional(const GaussianConditional& rhs);
|
||||
|
||||
/** Combine several GaussianConditional into a single dense GC. The
|
||||
* conditionals enumerated by \c first and \c last must be in increasing
|
||||
* order, meaning that the parents of any conditional may not include a
|
||||
* conditional coming before it.
|
||||
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<GaussianConditional>.
|
||||
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<GaussianConditional>. */
|
||||
template<typename ITERATOR>
|
||||
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
|
||||
|
||||
/** Assignment operator */
|
||||
GaussianConditional& operator=(const GaussianConditional& rhs);
|
||||
|
||||
|
|
@ -274,5 +283,49 @@ GaussianConditional::GaussianConditional(ITERATOR firstKey, ITERATOR lastKey,
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<typename ITERATOR>
|
||||
GaussianConditional::shared_ptr GaussianConditional::Combine(ITERATOR firstConditional, ITERATOR lastConditional) {
|
||||
|
||||
// TODO: check for being a clique
|
||||
|
||||
// Get dimensions from first conditional
|
||||
std::vector<size_t> dims; dims.reserve((*firstConditional)->size() + 1);
|
||||
for(const_iterator j = (*firstConditional)->begin(); j != (*firstConditional)->end(); ++j)
|
||||
dims.push_back((*firstConditional)->dim(j));
|
||||
dims.push_back(1);
|
||||
|
||||
// We assume the conditionals form clique, so the first n variables will be
|
||||
// frontal variables in the new conditional.
|
||||
size_t nFrontals = 0;
|
||||
size_t nRows = 0;
|
||||
for(ITERATOR c = firstConditional; c != lastConditional; ++c) {
|
||||
nRows += dims[nFrontals];
|
||||
++ nFrontals;
|
||||
}
|
||||
|
||||
// Allocate combined conditional, has same keys as firstConditional
|
||||
Matrix tempCombined;
|
||||
VerticalBlockView<Matrix> tempBlockView(tempCombined, dims.begin(), dims.end(), 0);
|
||||
GaussianConditional::shared_ptr combinedConditional(new GaussianConditional((*firstConditional)->begin(), (*firstConditional)->end(), nFrontals, tempBlockView, zero(nRows)));
|
||||
|
||||
// Resize to correct number of rows
|
||||
combinedConditional->matrix_.resize(nRows, combinedConditional->matrix_.cols());
|
||||
combinedConditional->rsd_.rowEnd() = combinedConditional->matrix_.rows();
|
||||
|
||||
// Copy matrix and sigmas
|
||||
const size_t totalDims = combinedConditional->matrix_.cols();
|
||||
size_t currentSlot = 0;
|
||||
for(ITERATOR c = firstConditional; c != lastConditional; ++c) {
|
||||
const size_t startRow = combinedConditional->rsd_.offset(currentSlot); // Start row is same as start column
|
||||
combinedConditional->rsd_.range(0, currentSlot).block(startRow, 0, dims[currentSlot], combinedConditional->rsd_.offset(currentSlot)).operator=(
|
||||
Matrix::Zero(dims[currentSlot], combinedConditional->rsd_.offset(currentSlot)));
|
||||
combinedConditional->rsd_.range(currentSlot, dims.size()).block(startRow, 0, dims[currentSlot], totalDims - startRow).operator=(
|
||||
(*c)->matrix_);
|
||||
combinedConditional->sigmas_.segment(startRow, dims[currentSlot]) = (*c)->sigmas_;
|
||||
++ currentSlot;
|
||||
}
|
||||
|
||||
return combinedConditional;
|
||||
}
|
||||
|
||||
} // gtsam
|
||||
|
|
|
|||
|
|
@ -103,6 +103,21 @@ TEST( GaussianJunctionTree, eliminate )
|
|||
EXPECT(assert_equal(*(bayesTree_expected.root()->children().front()), *(rootClique->children().front())));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST_UNSAFE( GaussianJunctionTree, GBNConstructor )
|
||||
{
|
||||
GaussianFactorGraph fg = createChain();
|
||||
GaussianJunctionTree jt(fg);
|
||||
BayesTree<GaussianConditional>::sharedClique root = jt.eliminate(&EliminateQR);
|
||||
BayesTree<GaussianConditional> expected;
|
||||
expected.insert(root);
|
||||
|
||||
GaussianBayesNet bn(*GaussianSequentialSolver(fg).eliminate());
|
||||
BayesTree<GaussianConditional> actual(bn);
|
||||
|
||||
EXPECT(assert_equal(expected, actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( GaussianJunctionTree, optimizeMultiFrontal )
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in New Issue