Generic BayesTree from BayesNet constructor, works for GaussianBayesTree(GaussianBayesNet)
parent
c842c5c9fd
commit
0531983c74
|
|
@ -234,15 +234,56 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class CONDITIONAL, class CLIQUE>
|
template<class CONDITIONAL, class CLIQUE>
|
||||||
BayesTree<CONDITIONAL,CLIQUE>::BayesTree() {
|
void BayesTree<CONDITIONAL,CLIQUE>::recursiveTreeBuild(const boost::shared_ptr<BayesTreeClique<IndexConditional> >& symbolic,
|
||||||
|
const std::vector<boost::shared_ptr<CONDITIONAL> >& conditionals,
|
||||||
|
const typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique& parent) {
|
||||||
|
|
||||||
|
// Helper function to build a non-symbolic tree (e.g. Gaussian) using a
|
||||||
|
// symbolic tree, used in the BT(BN) constructor.
|
||||||
|
|
||||||
|
// Build the current clique
|
||||||
|
FastList<typename CONDITIONAL::shared_ptr> cliqueConditionals;
|
||||||
|
BOOST_FOREACH(Index j, symbolic->conditional()->frontals()) {
|
||||||
|
cliqueConditionals.push_back(conditionals[j]); }
|
||||||
|
typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique thisClique(new CLIQUE(CONDITIONAL::Combine(cliqueConditionals.begin(), cliqueConditionals.end())));
|
||||||
|
|
||||||
|
// Add the new clique with the current parent
|
||||||
|
this->addClique(thisClique, parent);
|
||||||
|
|
||||||
|
// Build the children, whose parent is the new clique
|
||||||
|
BOOST_FOREACH(const BayesTree<IndexConditional>::sharedClique& child, symbolic->children()) {
|
||||||
|
this->recursiveTreeBuild(child, conditionals, thisClique); }
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
template<class CONDITIONAL, class CLIQUE>
|
||||||
|
BayesTree<CONDITIONAL,CLIQUE>::BayesTree(const BayesNet<CONDITIONAL>& bayesNet) {
|
||||||
|
// First generate symbolic BT to determine clique structure
|
||||||
|
BayesTree<IndexConditional> sbt(bayesNet);
|
||||||
|
|
||||||
|
// Build index of variables to conditionals
|
||||||
|
std::vector<boost::shared_ptr<CONDITIONAL> > conditionals(sbt.root()->conditional()->frontals().back() + 1);
|
||||||
|
BOOST_FOREACH(const boost::shared_ptr<CONDITIONAL>& c, bayesNet) {
|
||||||
|
if(c->nrFrontals() != 1)
|
||||||
|
throw std::invalid_argument("BayesTree constructor from BayesNet only supports single frontal variable conditionals");
|
||||||
|
if(c->firstFrontalKey() >= conditionals.size())
|
||||||
|
throw std::invalid_argument("An inconsistent BayesNet was passed into the BayesTree constructor!");
|
||||||
|
if(conditionals[c->firstFrontalKey()])
|
||||||
|
throw std::invalid_argument("An inconsistent BayesNet with duplicate frontal variables was passed into the BayesTree constructor!");
|
||||||
|
|
||||||
|
conditionals[c->firstFrontalKey()] = c;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the new tree
|
||||||
|
this->recursiveTreeBuild(sbt.root(), conditionals, sharedClique());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class CONDITIONAL, class CLIQUE>
|
template<>
|
||||||
BayesTree<CONDITIONAL,CLIQUE>::BayesTree(const BayesNet<CONDITIONAL>& bayesNet) {
|
inline BayesTree<IndexConditional>::BayesTree(const BayesNet<IndexConditional>& bayesNet) {
|
||||||
typename BayesNet<CONDITIONAL>::const_reverse_iterator rit;
|
typename BayesNet<IndexConditional>::const_reverse_iterator rit;
|
||||||
for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit )
|
for ( rit=bayesNet.rbegin(); rit != bayesNet.rend(); ++rit )
|
||||||
insert(*this, *rit);
|
insert(*this, *rit);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||||
|
#include <gtsam/inference/IndexConditional.h>
|
||||||
#include <gtsam/linear/VectorValues.h>
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
@ -127,15 +128,22 @@ namespace gtsam {
|
||||||
/** Fill the nodes index for a subtree */
|
/** Fill the nodes index for a subtree */
|
||||||
void fillNodesIndex(const sharedClique& subtree);
|
void fillNodesIndex(const sharedClique& subtree);
|
||||||
|
|
||||||
|
/** Helper function to build a non-symbolic tree (e.g. Gaussian) using a
|
||||||
|
* symbolic tree, used in the BT(BN) constructor.
|
||||||
|
*/
|
||||||
|
void recursiveTreeBuild(const boost::shared_ptr<BayesTreeClique<IndexConditional> >& symbolic,
|
||||||
|
const std::vector<boost::shared_ptr<CONDITIONAL> >& conditionals,
|
||||||
|
const typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique& parent);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Create an empty Bayes Tree */
|
/** Create an empty Bayes Tree */
|
||||||
BayesTree();
|
BayesTree() {}
|
||||||
|
|
||||||
/** Create a Bayes Tree from a Bayes Net */
|
/** Create a Bayes Tree from a Bayes Net (requires CONDITIONAL is IndexConditional *or* CONDITIONAL::Combine) */
|
||||||
BayesTree(const BayesNet<CONDITIONAL>& bayesNet);
|
BayesTree(const BayesNet<CONDITIONAL>& bayesNet);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <boost/weak_ptr.hpp>
|
||||||
|
|
||||||
#include <gtsam/base/types.h>
|
#include <gtsam/base/types.h>
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
|
||||||
|
|
@ -137,6 +137,15 @@ public:
|
||||||
/** Copy constructor */
|
/** Copy constructor */
|
||||||
GaussianConditional(const GaussianConditional& rhs);
|
GaussianConditional(const GaussianConditional& rhs);
|
||||||
|
|
||||||
|
/** Combine several GaussianConditional into a single dense GC. The
|
||||||
|
* conditionals enumerated by \c first and \c last must be in increasing
|
||||||
|
* order, meaning that the parents of any conditional may not include a
|
||||||
|
* conditional coming before it.
|
||||||
|
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<GaussianConditional>.
|
||||||
|
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<GaussianConditional>. */
|
||||||
|
template<typename ITERATOR>
|
||||||
|
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
|
||||||
|
|
||||||
/** Assignment operator */
|
/** Assignment operator */
|
||||||
GaussianConditional& operator=(const GaussianConditional& rhs);
|
GaussianConditional& operator=(const GaussianConditional& rhs);
|
||||||
|
|
||||||
|
|
@ -274,5 +283,49 @@ GaussianConditional::GaussianConditional(ITERATOR firstKey, ITERATOR lastKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
template<typename ITERATOR>
|
||||||
|
GaussianConditional::shared_ptr GaussianConditional::Combine(ITERATOR firstConditional, ITERATOR lastConditional) {
|
||||||
|
|
||||||
|
// TODO: check for being a clique
|
||||||
|
|
||||||
|
// Get dimensions from first conditional
|
||||||
|
std::vector<size_t> dims; dims.reserve((*firstConditional)->size() + 1);
|
||||||
|
for(const_iterator j = (*firstConditional)->begin(); j != (*firstConditional)->end(); ++j)
|
||||||
|
dims.push_back((*firstConditional)->dim(j));
|
||||||
|
dims.push_back(1);
|
||||||
|
|
||||||
|
// We assume the conditionals form clique, so the first n variables will be
|
||||||
|
// frontal variables in the new conditional.
|
||||||
|
size_t nFrontals = 0;
|
||||||
|
size_t nRows = 0;
|
||||||
|
for(ITERATOR c = firstConditional; c != lastConditional; ++c) {
|
||||||
|
nRows += dims[nFrontals];
|
||||||
|
++ nFrontals;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate combined conditional, has same keys as firstConditional
|
||||||
|
Matrix tempCombined;
|
||||||
|
VerticalBlockView<Matrix> tempBlockView(tempCombined, dims.begin(), dims.end(), 0);
|
||||||
|
GaussianConditional::shared_ptr combinedConditional(new GaussianConditional((*firstConditional)->begin(), (*firstConditional)->end(), nFrontals, tempBlockView, zero(nRows)));
|
||||||
|
|
||||||
|
// Resize to correct number of rows
|
||||||
|
combinedConditional->matrix_.resize(nRows, combinedConditional->matrix_.cols());
|
||||||
|
combinedConditional->rsd_.rowEnd() = combinedConditional->matrix_.rows();
|
||||||
|
|
||||||
|
// Copy matrix and sigmas
|
||||||
|
const size_t totalDims = combinedConditional->matrix_.cols();
|
||||||
|
size_t currentSlot = 0;
|
||||||
|
for(ITERATOR c = firstConditional; c != lastConditional; ++c) {
|
||||||
|
const size_t startRow = combinedConditional->rsd_.offset(currentSlot); // Start row is same as start column
|
||||||
|
combinedConditional->rsd_.range(0, currentSlot).block(startRow, 0, dims[currentSlot], combinedConditional->rsd_.offset(currentSlot)).operator=(
|
||||||
|
Matrix::Zero(dims[currentSlot], combinedConditional->rsd_.offset(currentSlot)));
|
||||||
|
combinedConditional->rsd_.range(currentSlot, dims.size()).block(startRow, 0, dims[currentSlot], totalDims - startRow).operator=(
|
||||||
|
(*c)->matrix_);
|
||||||
|
combinedConditional->sigmas_.segment(startRow, dims[currentSlot]) = (*c)->sigmas_;
|
||||||
|
++ currentSlot;
|
||||||
|
}
|
||||||
|
|
||||||
|
return combinedConditional;
|
||||||
|
}
|
||||||
|
|
||||||
} // gtsam
|
} // gtsam
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,21 @@ TEST( GaussianJunctionTree, eliminate )
|
||||||
EXPECT(assert_equal(*(bayesTree_expected.root()->children().front()), *(rootClique->children().front())));
|
EXPECT(assert_equal(*(bayesTree_expected.root()->children().front()), *(rootClique->children().front())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST_UNSAFE( GaussianJunctionTree, GBNConstructor )
|
||||||
|
{
|
||||||
|
GaussianFactorGraph fg = createChain();
|
||||||
|
GaussianJunctionTree jt(fg);
|
||||||
|
BayesTree<GaussianConditional>::sharedClique root = jt.eliminate(&EliminateQR);
|
||||||
|
BayesTree<GaussianConditional> expected;
|
||||||
|
expected.insert(root);
|
||||||
|
|
||||||
|
GaussianBayesNet bn(*GaussianSequentialSolver(fg).eliminate());
|
||||||
|
BayesTree<GaussianConditional> actual(bn);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected, actual));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( GaussianJunctionTree, optimizeMultiFrontal )
|
TEST( GaussianJunctionTree, optimizeMultiFrontal )
|
||||||
{
|
{
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue