Fixed computing marginals in BayesTree

release/4.3a0
Richard Roberts 2010-10-11 16:02:01 +00:00
parent ccea5c79cb
commit 96eb939749
6 changed files with 198 additions and 136 deletions

View File

@@ -329,27 +329,27 @@ namespace gtsam {
return p_S_R;
}
// /* ************************************************************************* */
// // P(C) = \int_R P(F|S) P(S|R) P(R)
// // TODO: Maybe we should integrate given parent marginal P(Cp),
// // \int(Cp\S) P(F|S)P(S|Cp)P(Cp)
// // Because the root clique could be very big.
// /* ************************************************************************* */
// template<class Conditional>
// template<class Factor>
// FactorGraph<Factor>
// BayesTree<Conditional>::Clique::marginal(shared_ptr R) {
// // If we are the root, just return this root
// if (R.get()==this) return *R;
//
// // Combine P(F|S), P(S|R), and P(R)
// BayesNet<Conditional> p_FSR = this->shortcut<Factor>(R);
// p_FSR.push_front(*this);
// p_FSR.push_back(*R);
//
// // Find marginal on the keys we are interested in
// return marginalize<Factor,Conditional>(p_FSR,keys());
// }
/* ************************************************************************* */
// P(C) = \int_R P(F|S) P(S|R) P(R)
// TODO: Maybe we should integrate given parent marginal P(Cp),
// \int(Cp\S) P(F|S)P(S|Cp)P(Cp)
// Because the root clique could be very big.
/* ************************************************************************* */
template<class Conditional>
template<class FactorGraph>
FactorGraph
BayesTree<Conditional>::Clique::marginal(shared_ptr R) {
// If we are the root, just return this root
if (R.get()==this) return *R;
// Combine P(F|S), P(S|R), and P(R)
BayesNet<Conditional> p_FSR = this->shortcut<FactorGraph>(R);
p_FSR.push_front(*this);
p_FSR.push_back(*R);
// Find marginal on the keys we are interested in
return FactorGraph(*Inference::Marginal(FactorGraph(p_FSR), keys()));
}
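A minimal usage sketch of the new clique marginal, assuming the GaussianFactorGraph instantiation exercised by the updated tests further below (bayesTree and ordering stand in for the test fixtures):
// Marginal P(C) = \int_R P(F|S) P(S|R) P(R), returned as a factor graph on the clique keys
GaussianISAM::sharedClique R = bayesTree.root(), C = bayesTree[ordering["x1"]];
GaussianFactorGraph p_C = C->marginal<GaussianFactorGraph>(R);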
// /* ************************************************************************* */
// // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R)
@@ -676,41 +676,51 @@ namespace gtsam {
}
}
// /* ************************************************************************* */
// // First finds clique marginal then marginalizes that
// /* ************************************************************************* */
// template<class Conditional>
// template<class Factor>
// FactorGraph<Factor>
// BayesTree<Conditional>::marginal(varid_t key) const {
//
// // get clique containing key
// sharedClique clique = (*this)[key];
//
// // calculate or retrieve its marginal
// FactorGraph<Factor> cliqueMarginal = clique->marginal<Factor>(root_);
//
// // create an ordering where only the requested key is not eliminated
// vector<varid_t> ord = clique->keys();
// ord.remove(key);
//
// // partially eliminate, remaining factor graph is requested marginal
// eliminate<Factor,Conditional>(cliqueMarginal,ord);
// return cliqueMarginal;
// }
/* ************************************************************************* */
// First finds clique marginal then marginalizes that
/* ************************************************************************* */
template<class Conditional>
template<class FactorGraph>
FactorGraph
BayesTree<Conditional>::marginal(varid_t key) const {
// get clique containing key
sharedClique clique = (*this)[key];
// calculate or retrieve its marginal
FactorGraph cliqueMarginal = clique->marginal<FactorGraph>(root_);
// Reorder so that only the requested key is not eliminated
typename FactorGraph::variableindex_type varIndex(cliqueMarginal);
vector<varid_t> keyAsVector(1); keyAsVector[0] = key;
Permutation toBack(Permutation::PushToBack(keyAsVector, varIndex.size()));
Permutation::shared_ptr toBackInverse(toBack.inverse());
varIndex.permute(toBack);
BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, cliqueMarginal) {
factor->permuteWithInverse(*toBackInverse);
}
// partially eliminate, remaining factor graph is requested marginal
Inference::EliminateUntil(cliqueMarginal, varIndex.size()-1, varIndex);
BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, cliqueMarginal) {
if(factor)
factor->permuteWithInverse(toBack);
}
return cliqueMarginal;
}
// /* ************************************************************************* */
// template<class Conditional>
// template<class Factor>
// BayesNet<Conditional>
// BayesTree<Conditional>::marginalBayesNet(varid_t key) const {
//
// // calculate marginal as a factor graph
// FactorGraph<Factor> fg = this->marginal<Factor>(key);
//
// // eliminate further to Bayes net
// return eliminate<Factor,Conditional>(fg,Ordering(key));
// }
/* ************************************************************************* */
template<class Conditional>
template<class FactorGraph>
BayesNet<Conditional>
BayesTree<Conditional>::marginalBayesNet(varid_t key) const {
// calculate marginal as a factor graph
FactorGraph fg = this->marginal<FactorGraph>(key);
// eliminate further to Bayes net
return *Inference::Eliminate(fg);
}
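For reference, the two variable-level entry points as used in the updated tests below; a sketch assuming a GaussianISAM tree and the test ordering:
// Marginal on a single variable, as a factor graph and further eliminated to a Bayes net
GaussianFactorGraph margFG = bayesTree.marginal<GaussianFactorGraph>(ordering["x2"]);
GaussianBayesNet margBN = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x2"]);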
// /* ************************************************************************* */
// // Find two cliques, their joint, then marginalizes

View File

@@ -99,10 +99,10 @@ namespace gtsam {
template<class FactorGraph>
BayesNet<Conditional> shortcut(shared_ptr root);
// /** return the marginal P(C) of the clique */
// template<class Factor>
// FactorGraph<Factor> marginal(shared_ptr root);
//
/** return the marginal P(C) of the clique */
template<class FactorGraph>
FactorGraph marginal(shared_ptr root);
// /** return the joint P(C1,C2), where C1==this. TODO: not a method? */
// template<class Factor>
// std::pair<FactorGraph<Factor>,Ordering> joint(shared_ptr C2, shared_ptr root);
@@ -245,14 +245,14 @@ namespace gtsam {
/** Gather data on all cliques */
CliqueData getCliqueData() const;
// /** return marginal on any variable */
// template<class Factor>
// FactorGraph<Factor> marginal(varid_t key) const;
//
// /** return marginal on any variable, as a Bayes Net */
// template<class Factor>
// BayesNet<Conditional> marginalBayesNet(varid_t key) const;
//
/** return marginal on any variable */
template<class FactorGraph>
FactorGraph marginal(varid_t key) const;
/** return marginal on any variable, as a Bayes Net */
template<class FactorGraph>
BayesNet<Conditional> marginalBayesNet(varid_t key) const;
// /** return joint on two variables */
// template<class Factor>
// FactorGraph<Factor> joint(varid_t key1, varid_t key2) const;

View File

@@ -30,12 +30,14 @@ Permutation Permutation::PullToFront(const vector<varid_t>& toFront, size_t size
// Mask of which variables have been pulled, used to reorder
vector<bool> pulled(size, false);
// Put the pulled variables at the front of the permutation and set up the
// pulled flags.
for(varid_t j=0; j<toFront.size(); ++j) {
ret[j] = toFront[j];
pulled[toFront[j]] = true;
assert(toFront[j] < size);
}
// Fill in the rest of the variables
varid_t nextVar = toFront.size();
for(varid_t j=0; j<size; ++j)
if(!pulled[j])
@@ -45,6 +47,33 @@ Permutation Permutation::PullToFront(const vector<varid_t>& toFront, size_t size
return ret;
}
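A small worked sketch of PullToFront, reading the result as position -> variable (the same convention as ret[j] above); not part of the commit:
vector<varid_t> toFront; toFront.push_back(2); toFront.push_back(0);
Permutation p = Permutation::PullToFront(toFront, 4);
// p is [2, 0, 1, 3]: the listed variables come first, in the order given,
// and the remaining variables 1 and 3 follow in their original order.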
/* ************************************************************************* */
Permutation Permutation::PushToBack(const std::vector<varid_t>& toBack, size_t size) {
Permutation ret(size);
// Mask of which variables have been pushed, used to reorder
vector<bool> pushed(size, false);
// Put the pushed variables at the back of the permutation and set up the
// pushed flags.
varid_t nextVar = size - toBack.size();
for(varid_t j=0; j<toBack.size(); ++j) {
ret[nextVar++] = toBack[j];
pushed[toBack[j]] = true;
}
assert(nextVar == size);
// Fill in the rest of the variables
nextVar = 0;
for(varid_t j=0; j<size; ++j)
if(!pushed[j])
ret[nextVar++] = j;
assert(nextVar == size - toBack.size());
return ret;
}
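And the mirror image for the new PushToBack; the same kind of worked sketch:
vector<varid_t> toBack; toBack.push_back(1);
Permutation p = Permutation::PushToBack(toBack, 4);
// p is [0, 2, 3, 1]: variables 0, 2, 3 keep their relative order at the
// front, and variable 1 is moved to the last position.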
/* ************************************************************************* */
Permutation::shared_ptr Permutation::permute(const Permutation& permutation) const {
const size_t nVars = permutation.size();

View File

@@ -87,6 +87,12 @@ public:
*/
static Permutation PullToFront(const std::vector<varid_t>& toFront, size_t size);
/**
* Create a permutation that pushes the given variables to the back while
* keeping the remaining variables in their original order at the front.
*/
static Permutation PushToBack(const std::vector<varid_t>& toBack, size_t size);
iterator begin() { return rangeIndices_.begin(); }
const_iterator begin() const { return rangeIndices_.begin(); }
iterator end() { return rangeIndices_.end(); }

View File

@@ -51,6 +51,9 @@ public:
template<class Container>
VectorValues(const Container& dimensions);
/** Construct to hold nVars vectors of varDim dimension each. */
VectorValues(varid_t nVars, size_t varDim);
/** Construct from a container of variable dimensions in variable order and
* a combined Vector of all of the variables in order.
*/
@@ -179,6 +182,14 @@ inline VectorValues::VectorValues(const Container& dimensions) : varStarts_(dime
values_.resize(varStarts_.back(), false);
}
inline VectorValues::VectorValues(varid_t nVars, size_t varDim) : varStarts_(nVars+1) {
varStarts_[0] = 0;
size_t varStart = 0;
for(varid_t j=1; j<=nVars; ++j)
varStarts_[j] = (varStart += varDim);
values_.resize(varStarts_.back(), false);
}
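This is the constructor the updated balanced_smoother_marginals test uses to build its expected solution; a brief sketch:
// Seven variables of dimension 2 each; makeZero() then fills them with zeros
VectorValues expectedSolution(7, 2);
expectedSolution.makeZero();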
inline VectorValues::VectorValues(const std::vector<size_t>& dimensions, const Vector& values) :
values_(values), varStarts_(dimensions.size()+1) {
varStarts_[0] = 0;

View File

@@ -157,54 +157,53 @@ TEST( BayesTree, linear_smoother_shortcuts )
C4 x7 : x6
************************************************************************* */
// SL-FIX TEST( BayesTree, balanced_smoother_marginals )
//{
// // Create smoother with 7 nodes
// GaussianFactorGraph smoother = createSmoother(7);
// Ordering ordering;
// ordering += "x1","x3","x5","x7","x2","x6","x4";
//
// // eliminate using a "nested dissection" ordering
// GaussianBayesNet chordalBayesNet = smoother.eliminate(ordering);
//
// VectorValues expectedSolution;
// BOOST_FOREACH(string key, ordering)
// expectedSolution.insert(key,zero(2));
// VectorValues actualSolution = optimize(chordalBayesNet);
// CHECK(assert_equal(expectedSolution,actualSolution,tol));
//
// // Create the Bayes tree
// GaussianISAM bayesTree(chordalBayesNet);
// LONGS_EQUAL(4,bayesTree.size());
//
// double tol=1e-5;
//
// // Check marginal on x1
// GaussianBayesNet expected1 = simpleGaussian("x1", zero(2), sigmax1);
// GaussianBayesNet actual1 = bayesTree.marginalBayesNet<GaussianFactor>("x1");
// CHECK(assert_equal(expected1,actual1,tol));
//
// // Check marginal on x2
// double sigx2 = 0.68712938; // FIXME: this should be corrected analytically
// GaussianBayesNet expected2 = simpleGaussian("x2", zero(2), sigx2);
// GaussianBayesNet actual2 = bayesTree.marginalBayesNet<GaussianFactor>("x2");
// CHECK(assert_equal(expected2,actual2,tol)); // FAILS
//
// // Check marginal on x3
// GaussianBayesNet expected3 = simpleGaussian("x3", zero(2), sigmax3);
// GaussianBayesNet actual3 = bayesTree.marginalBayesNet<GaussianFactor>("x3");
// CHECK(assert_equal(expected3,actual3,tol));
//
// // Check marginal on x4
// GaussianBayesNet expected4 = simpleGaussian("x4", zero(2), sigmax4);
// GaussianBayesNet actual4 = bayesTree.marginalBayesNet<GaussianFactor>("x4");
// CHECK(assert_equal(expected4,actual4,tol));
//
// // Check marginal on x7 (should be equal to x1)
// GaussianBayesNet expected7 = simpleGaussian("x7", zero(2), sigmax7);
// GaussianBayesNet actual7 = bayesTree.marginalBayesNet<GaussianFactor>("x7");
// CHECK(assert_equal(expected7,actual7,tol));
//}
TEST( BayesTree, balanced_smoother_marginals )
{
// Create smoother with 7 nodes
Ordering ordering;
ordering += "x1","x3","x5","x7","x2","x6","x4";
GaussianFactorGraph smoother = createSmoother(7, ordering).first;
// eliminate using a "nested dissection" ordering
GaussianBayesNet chordalBayesNet = *Inference::Eliminate(smoother);
VectorValues expectedSolution(7, 2);
expectedSolution.makeZero();
VectorValues actualSolution = optimize(chordalBayesNet);
CHECK(assert_equal(expectedSolution,actualSolution,tol));
// Create the Bayes tree
GaussianISAM bayesTree(chordalBayesNet);
LONGS_EQUAL(4,bayesTree.size());
double tol=1e-5;
// Check marginal on x1
GaussianBayesNet expected1 = simpleGaussian(ordering["x1"], zero(2), sigmax1);
GaussianBayesNet actual1 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x1"]);
CHECK(assert_equal(expected1,actual1,tol));
// Check marginal on x2
double sigx2 = 0.68712938; // FIXME: this should be corrected analytically
GaussianBayesNet expected2 = simpleGaussian(ordering["x2"], zero(2), sigx2);
GaussianBayesNet actual2 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x2"]);
CHECK(assert_equal(expected2,actual2,tol)); // FAILS
// Check marginal on x3
GaussianBayesNet expected3 = simpleGaussian(ordering["x3"], zero(2), sigmax3);
GaussianBayesNet actual3 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x3"]);
CHECK(assert_equal(expected3,actual3,tol));
// Check marginal on x4
GaussianBayesNet expected4 = simpleGaussian(ordering["x4"], zero(2), sigmax4);
GaussianBayesNet actual4 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x4"]);
CHECK(assert_equal(expected4,actual4,tol));
// Check marginal on x7 (should be equal to x1)
GaussianBayesNet expected7 = simpleGaussian(ordering["x7"], zero(2), sigmax7);
GaussianBayesNet actual7 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x7"]);
CHECK(assert_equal(expected7,actual7,tol));
}
/* ************************************************************************* */
TEST( BayesTree, balanced_smoother_shortcuts )
@@ -238,26 +237,33 @@ TEST( BayesTree, balanced_smoother_shortcuts )
}
/* ************************************************************************* */
// SL-FIX TEST( BayesTree, balanced_smoother_clique_marginals )
//{
// // Create smoother with 7 nodes
// GaussianFactorGraph smoother = createSmoother(7);
// Ordering ordering;
// ordering += "x1","x3","x5","x7","x2","x6","x4";
//
// // Create the Bayes tree
// GaussianBayesNet chordalBayesNet = smoother.eliminate(ordering);
// GaussianISAM bayesTree(chordalBayesNet);
//
// // Check the clique marginal P(C3)
// double sigmax2_alt = 1/1.45533; // THIS NEEDS TO BE CHECKED!
// GaussianBayesNet expected = simpleGaussian("x2",zero(2),sigmax2_alt);
// push_front(expected,"x1", zero(2), eye(2)*sqrt(2), "x2", -eye(2)*sqrt(2)/2, ones(2));
// GaussianISAM::sharedClique R = bayesTree.root(), C3 = bayesTree["x1"];
// FactorGraph<GaussianFactor> marginal = C3->marginal<GaussianFactor>(R);
// GaussianBayesNet actual = eliminate<GaussianFactor,GaussianConditional>(marginal,C3->keys());
// CHECK(assert_equal(expected,actual,tol));
//}
TEST( BayesTree, balanced_smoother_clique_marginals )
{
// Create smoother with 7 nodes
Ordering ordering;
ordering += "x1","x3","x5","x7","x2","x6","x4";
GaussianFactorGraph smoother = createSmoother(7, ordering).first;
// Create the Bayes tree
GaussianBayesNet chordalBayesNet = *Inference::Eliminate(smoother);
GaussianISAM bayesTree(chordalBayesNet);
// Check the clique marginal P(C3)
double sigmax2_alt = 1/1.45533; // THIS NEEDS TO BE CHECKED!
GaussianBayesNet expected = simpleGaussian(ordering["x2"],zero(2),sigmax2_alt);
push_front(expected,ordering["x1"], zero(2), eye(2)*sqrt(2), ordering["x2"], -eye(2)*sqrt(2)/2, ones(2));
GaussianISAM::sharedClique R = bayesTree.root(), C3 = bayesTree[ordering["x1"]];
GaussianFactorGraph marginal = C3->marginal<GaussianFactorGraph>(R);
GaussianVariableIndex<> varIndex(marginal);
Permutation toFront(Permutation::PullToFront(C3->keys(), varIndex.size()));
Permutation toFrontInverse(*toFront.inverse());
varIndex.permute(toFront);
BOOST_FOREACH(const GaussianFactor::shared_ptr& factor, marginal) {
factor->permuteWithInverse(toFrontInverse); }
GaussianBayesNet actual = *Inference::EliminateUntil(marginal, C3->keys().size(), varIndex);
actual.permuteWithInverse(toFront);
CHECK(assert_equal(expected,actual,tol));
}
/* ************************************************************************* */
// SL-FIX TEST( BayesTree, balanced_smoother_joint )