Fixed computing marginals in BayesTree
parent
ccea5c79cb
commit
96eb939749
|
@ -329,27 +329,27 @@ namespace gtsam {
|
|||
return p_S_R;
|
||||
}
|
||||
|
||||
// /* ************************************************************************* */
|
||||
// // P(C) = \int_R P(F|S) P(S|R) P(R)
|
||||
// // TODO: Maybe we should integrate given parent marginal P(Cp),
|
||||
// // \int(Cp\S) P(F|S)P(S|Cp)P(Cp)
|
||||
// // Because the root clique could be very big.
|
||||
// /* ************************************************************************* */
|
||||
// template<class Conditional>
|
||||
// template<class Factor>
|
||||
// FactorGraph<Factor>
|
||||
// BayesTree<Conditional>::Clique::marginal(shared_ptr R) {
|
||||
// // If we are the root, just return this root
|
||||
// if (R.get()==this) return *R;
|
||||
//
|
||||
// // Combine P(F|S), P(S|R), and P(R)
|
||||
// BayesNet<Conditional> p_FSR = this->shortcut<Factor>(R);
|
||||
// p_FSR.push_front(*this);
|
||||
// p_FSR.push_back(*R);
|
||||
//
|
||||
// // Find marginal on the keys we are interested in
|
||||
// return marginalize<Factor,Conditional>(p_FSR,keys());
|
||||
// }
|
||||
/* ************************************************************************* */
|
||||
// P(C) = \int_R P(F|S) P(S|R) P(R)
|
||||
// TODO: Maybe we should integrate given parent marginal P(Cp),
|
||||
// \int(Cp\S) P(F|S)P(S|Cp)P(Cp)
|
||||
// Because the root clique could be very big.
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
template<class FactorGraph>
|
||||
FactorGraph
|
||||
BayesTree<Conditional>::Clique::marginal(shared_ptr R) {
|
||||
// If we are the root, just return this root
|
||||
if (R.get()==this) return *R;
|
||||
|
||||
// Combine P(F|S), P(S|R), and P(R)
|
||||
BayesNet<Conditional> p_FSR = this->shortcut<FactorGraph>(R);
|
||||
p_FSR.push_front(*this);
|
||||
p_FSR.push_back(*R);
|
||||
|
||||
// Find marginal on the keys we are interested in
|
||||
return FactorGraph(*Inference::Marginal(FactorGraph(p_FSR), keys()));
|
||||
}
|
||||
|
||||
// /* ************************************************************************* */
|
||||
// // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R)
|
||||
|
@ -676,41 +676,51 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
// /* ************************************************************************* */
|
||||
// // First finds clique marginal then marginalizes that
|
||||
// /* ************************************************************************* */
|
||||
// template<class Conditional>
|
||||
// template<class Factor>
|
||||
// FactorGraph<Factor>
|
||||
// BayesTree<Conditional>::marginal(varid_t key) const {
|
||||
//
|
||||
// // get clique containing key
|
||||
// sharedClique clique = (*this)[key];
|
||||
//
|
||||
// // calculate or retrieve its marginal
|
||||
// FactorGraph<Factor> cliqueMarginal = clique->marginal<Factor>(root_);
|
||||
//
|
||||
// // create an ordering where only the requested key is not eliminated
|
||||
// vector<varid_t> ord = clique->keys();
|
||||
// ord.remove(key);
|
||||
//
|
||||
// // partially eliminate, remaining factor graph is requested marginal
|
||||
// eliminate<Factor,Conditional>(cliqueMarginal,ord);
|
||||
// return cliqueMarginal;
|
||||
// }
|
||||
/* ************************************************************************* */
|
||||
// First finds clique marginal then marginalizes that
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
template<class FactorGraph>
|
||||
FactorGraph
|
||||
BayesTree<Conditional>::marginal(varid_t key) const {
|
||||
|
||||
// /* ************************************************************************* */
|
||||
// template<class Conditional>
|
||||
// template<class Factor>
|
||||
// BayesNet<Conditional>
|
||||
// BayesTree<Conditional>::marginalBayesNet(varid_t key) const {
|
||||
//
|
||||
// // calculate marginal as a factor graph
|
||||
// FactorGraph<Factor> fg = this->marginal<Factor>(key);
|
||||
//
|
||||
// // eliminate further to Bayes net
|
||||
// return eliminate<Factor,Conditional>(fg,Ordering(key));
|
||||
// }
|
||||
// get clique containing key
|
||||
sharedClique clique = (*this)[key];
|
||||
|
||||
// calculate or retrieve its marginal
|
||||
FactorGraph cliqueMarginal = clique->marginal<FactorGraph>(root_);
|
||||
|
||||
// Reorder so that only the requested key is not eliminated
|
||||
typename FactorGraph::variableindex_type varIndex(cliqueMarginal);
|
||||
vector<varid_t> keyAsVector(1); keyAsVector[0] = key;
|
||||
Permutation toBack(Permutation::PushToBack(keyAsVector, varIndex.size()));
|
||||
Permutation::shared_ptr toBackInverse(toBack.inverse());
|
||||
varIndex.permute(toBack);
|
||||
BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, cliqueMarginal) {
|
||||
factor->permuteWithInverse(*toBackInverse);
|
||||
}
|
||||
|
||||
// partially eliminate, remaining factor graph is requested marginal
|
||||
Inference::EliminateUntil(cliqueMarginal, varIndex.size()-1, varIndex);
|
||||
BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, cliqueMarginal) {
|
||||
if(factor)
|
||||
factor->permuteWithInverse(toBack);
|
||||
}
|
||||
return cliqueMarginal;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class Conditional>
|
||||
template<class FactorGraph>
|
||||
BayesNet<Conditional>
|
||||
BayesTree<Conditional>::marginalBayesNet(varid_t key) const {
|
||||
|
||||
// calculate marginal as a factor graph
|
||||
FactorGraph fg = this->marginal<FactorGraph>(key);
|
||||
|
||||
// eliminate further to Bayes net
|
||||
return *Inference::Eliminate(fg);
|
||||
}
|
||||
|
||||
// /* ************************************************************************* */
|
||||
// // Find two cliques, their joint, then marginalizes
|
||||
|
|
|
@ -99,10 +99,10 @@ namespace gtsam {
|
|||
template<class FactorGraph>
|
||||
BayesNet<Conditional> shortcut(shared_ptr root);
|
||||
|
||||
// /** return the marginal P(C) of the clique */
|
||||
// template<class Factor>
|
||||
// FactorGraph<Factor> marginal(shared_ptr root);
|
||||
//
|
||||
/** return the marginal P(C) of the clique */
|
||||
template<class FactorGraph>
|
||||
FactorGraph marginal(shared_ptr root);
|
||||
|
||||
// /** return the joint P(C1,C2), where C1==this. TODO: not a method? */
|
||||
// template<class Factor>
|
||||
// std::pair<FactorGraph<Factor>,Ordering> joint(shared_ptr C2, shared_ptr root);
|
||||
|
@ -245,14 +245,14 @@ namespace gtsam {
|
|||
/** Gather data on all cliques */
|
||||
CliqueData getCliqueData() const;
|
||||
|
||||
// /** return marginal on any variable */
|
||||
// template<class Factor>
|
||||
// FactorGraph<Factor> marginal(varid_t key) const;
|
||||
//
|
||||
// /** return marginal on any variable, as a Bayes Net */
|
||||
// template<class Factor>
|
||||
// BayesNet<Conditional> marginalBayesNet(varid_t key) const;
|
||||
//
|
||||
/** return marginal on any variable */
|
||||
template<class FactorGraph>
|
||||
FactorGraph marginal(varid_t key) const;
|
||||
|
||||
/** return marginal on any variable, as a Bayes Net */
|
||||
template<class FactorGraph>
|
||||
BayesNet<Conditional> marginalBayesNet(varid_t key) const;
|
||||
|
||||
// /** return joint on two variables */
|
||||
// template<class Factor>
|
||||
// FactorGraph<Factor> joint(varid_t key1, varid_t key2) const;
|
||||
|
|
|
@ -30,12 +30,14 @@ Permutation Permutation::PullToFront(const vector<varid_t>& toFront, size_t size
|
|||
// Mask of which variables have been pulled, used to reorder
|
||||
vector<bool> pulled(size, false);
|
||||
|
||||
// Put the pulled variables at the front of the permutation and set up the
|
||||
// pulled flags.
|
||||
for(varid_t j=0; j<toFront.size(); ++j) {
|
||||
ret[j] = toFront[j];
|
||||
pulled[toFront[j]] = true;
|
||||
assert(toFront[j] < size);
|
||||
}
|
||||
|
||||
// Fill in the rest of the variables
|
||||
varid_t nextVar = toFront.size();
|
||||
for(varid_t j=0; j<size; ++j)
|
||||
if(!pulled[j])
|
||||
|
@ -45,6 +47,33 @@ Permutation Permutation::PullToFront(const vector<varid_t>& toFront, size_t size
|
|||
return ret;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Permutation Permutation::PushToBack(const std::vector<varid_t>& toBack, size_t size) {
|
||||
|
||||
Permutation ret(size);
|
||||
|
||||
// Mask of which variables have been pushed, used to reorder
|
||||
vector<bool> pushed(size, false);
|
||||
|
||||
// Put the pushed variables at the back of the permutation and set up the
|
||||
// pushed flags;
|
||||
varid_t nextVar = size - toBack.size();
|
||||
for(varid_t j=0; j<toBack.size(); ++j) {
|
||||
ret[nextVar++] = toBack[j];
|
||||
pushed[toBack[j]] = true;
|
||||
}
|
||||
assert(nextVar == size);
|
||||
|
||||
// Fill in the rest of the variables
|
||||
nextVar = 0;
|
||||
for(varid_t j=0; j<size; ++j)
|
||||
if(!pushed[j])
|
||||
ret[nextVar++] = j;
|
||||
assert(nextVar == size - toBack.size());
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Permutation::shared_ptr Permutation::permute(const Permutation& permutation) const {
|
||||
const size_t nVars = permutation.size();
|
||||
|
|
|
@ -87,6 +87,12 @@ public:
|
|||
*/
|
||||
static Permutation PullToFront(const std::vector<varid_t>& toFront, size_t size);
|
||||
|
||||
/**
|
||||
* Create a permutation that pulls the given variables to the front while
|
||||
* pushing the rest to the back.
|
||||
*/
|
||||
static Permutation PushToBack(const std::vector<varid_t>& toBack, size_t size);
|
||||
|
||||
iterator begin() { return rangeIndices_.begin(); }
|
||||
const_iterator begin() const { return rangeIndices_.begin(); }
|
||||
iterator end() { return rangeIndices_.end(); }
|
||||
|
|
|
@ -51,6 +51,9 @@ public:
|
|||
template<class Container>
|
||||
VectorValues(const Container& dimensions);
|
||||
|
||||
/** Construct to hold nVars vectors of varDim dimension each. */
|
||||
VectorValues(varid_t nVars, size_t varDim);
|
||||
|
||||
/** Construct from a container of variable dimensions in variable order and
|
||||
* a combined Vector of all of the variables in order.
|
||||
*/
|
||||
|
@ -179,6 +182,14 @@ inline VectorValues::VectorValues(const Container& dimensions) : varStarts_(dime
|
|||
values_.resize(varStarts_.back(), false);
|
||||
}
|
||||
|
||||
inline VectorValues::VectorValues(varid_t nVars, size_t varDim) : varStarts_(nVars+1) {
|
||||
varStarts_[0] = 0;
|
||||
size_t varStart = 0;
|
||||
for(varid_t j=1; j<=nVars; ++j)
|
||||
varStarts_[j] = (varStart += varDim);
|
||||
values_.resize(varStarts_.back(), false);
|
||||
}
|
||||
|
||||
inline VectorValues::VectorValues(const std::vector<size_t>& dimensions, const Vector& values) :
|
||||
values_(values), varStarts_(dimensions.size()+1) {
|
||||
varStarts_[0] = 0;
|
||||
|
|
|
@ -157,54 +157,53 @@ TEST( BayesTree, linear_smoother_shortcuts )
|
|||
C4 x7 : x6
|
||||
|
||||
************************************************************************* */
|
||||
// SL-FIX TEST( BayesTree, balanced_smoother_marginals )
|
||||
//{
|
||||
// // Create smoother with 7 nodes
|
||||
// GaussianFactorGraph smoother = createSmoother(7);
|
||||
// Ordering ordering;
|
||||
// ordering += "x1","x3","x5","x7","x2","x6","x4";
|
||||
//
|
||||
// // eliminate using a "nested dissection" ordering
|
||||
// GaussianBayesNet chordalBayesNet = smoother.eliminate(ordering);
|
||||
//
|
||||
// VectorValues expectedSolution;
|
||||
// BOOST_FOREACH(string key, ordering)
|
||||
// expectedSolution.insert(key,zero(2));
|
||||
// VectorValues actualSolution = optimize(chordalBayesNet);
|
||||
// CHECK(assert_equal(expectedSolution,actualSolution,tol));
|
||||
//
|
||||
// // Create the Bayes tree
|
||||
// GaussianISAM bayesTree(chordalBayesNet);
|
||||
// LONGS_EQUAL(4,bayesTree.size());
|
||||
//
|
||||
// double tol=1e-5;
|
||||
//
|
||||
// // Check marginal on x1
|
||||
// GaussianBayesNet expected1 = simpleGaussian("x1", zero(2), sigmax1);
|
||||
// GaussianBayesNet actual1 = bayesTree.marginalBayesNet<GaussianFactor>("x1");
|
||||
// CHECK(assert_equal(expected1,actual1,tol));
|
||||
//
|
||||
// // Check marginal on x2
|
||||
// double sigx2 = 0.68712938; // FIXME: this should be corrected analytically
|
||||
// GaussianBayesNet expected2 = simpleGaussian("x2", zero(2), sigx2);
|
||||
// GaussianBayesNet actual2 = bayesTree.marginalBayesNet<GaussianFactor>("x2");
|
||||
// CHECK(assert_equal(expected2,actual2,tol)); // FAILS
|
||||
//
|
||||
// // Check marginal on x3
|
||||
// GaussianBayesNet expected3 = simpleGaussian("x3", zero(2), sigmax3);
|
||||
// GaussianBayesNet actual3 = bayesTree.marginalBayesNet<GaussianFactor>("x3");
|
||||
// CHECK(assert_equal(expected3,actual3,tol));
|
||||
//
|
||||
// // Check marginal on x4
|
||||
// GaussianBayesNet expected4 = simpleGaussian("x4", zero(2), sigmax4);
|
||||
// GaussianBayesNet actual4 = bayesTree.marginalBayesNet<GaussianFactor>("x4");
|
||||
// CHECK(assert_equal(expected4,actual4,tol));
|
||||
//
|
||||
// // Check marginal on x7 (should be equal to x1)
|
||||
// GaussianBayesNet expected7 = simpleGaussian("x7", zero(2), sigmax7);
|
||||
// GaussianBayesNet actual7 = bayesTree.marginalBayesNet<GaussianFactor>("x7");
|
||||
// CHECK(assert_equal(expected7,actual7,tol));
|
||||
//}
|
||||
TEST( BayesTree, balanced_smoother_marginals )
|
||||
{
|
||||
// Create smoother with 7 nodes
|
||||
Ordering ordering;
|
||||
ordering += "x1","x3","x5","x7","x2","x6","x4";
|
||||
GaussianFactorGraph smoother = createSmoother(7, ordering).first;
|
||||
|
||||
// Create the Bayes tree
|
||||
GaussianBayesNet chordalBayesNet = *Inference::Eliminate(smoother);
|
||||
|
||||
VectorValues expectedSolution(7, 2);
|
||||
expectedSolution.makeZero();
|
||||
VectorValues actualSolution = optimize(chordalBayesNet);
|
||||
CHECK(assert_equal(expectedSolution,actualSolution,tol));
|
||||
|
||||
// Create the Bayes tree
|
||||
GaussianISAM bayesTree(chordalBayesNet);
|
||||
LONGS_EQUAL(4,bayesTree.size());
|
||||
|
||||
double tol=1e-5;
|
||||
|
||||
// Check marginal on x1
|
||||
GaussianBayesNet expected1 = simpleGaussian(ordering["x1"], zero(2), sigmax1);
|
||||
GaussianBayesNet actual1 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x1"]);
|
||||
CHECK(assert_equal(expected1,actual1,tol));
|
||||
|
||||
// Check marginal on x2
|
||||
double sigx2 = 0.68712938; // FIXME: this should be corrected analytically
|
||||
GaussianBayesNet expected2 = simpleGaussian(ordering["x2"], zero(2), sigx2);
|
||||
GaussianBayesNet actual2 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x2"]);
|
||||
CHECK(assert_equal(expected2,actual2,tol)); // FAILS
|
||||
|
||||
// Check marginal on x3
|
||||
GaussianBayesNet expected3 = simpleGaussian(ordering["x3"], zero(2), sigmax3);
|
||||
GaussianBayesNet actual3 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x3"]);
|
||||
CHECK(assert_equal(expected3,actual3,tol));
|
||||
|
||||
// Check marginal on x4
|
||||
GaussianBayesNet expected4 = simpleGaussian(ordering["x4"], zero(2), sigmax4);
|
||||
GaussianBayesNet actual4 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x4"]);
|
||||
CHECK(assert_equal(expected4,actual4,tol));
|
||||
|
||||
// Check marginal on x7 (should be equal to x1)
|
||||
GaussianBayesNet expected7 = simpleGaussian(ordering["x7"], zero(2), sigmax7);
|
||||
GaussianBayesNet actual7 = bayesTree.marginalBayesNet<GaussianFactorGraph>(ordering["x7"]);
|
||||
CHECK(assert_equal(expected7,actual7,tol));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( BayesTree, balanced_smoother_shortcuts )
|
||||
|
@ -238,26 +237,33 @@ TEST( BayesTree, balanced_smoother_shortcuts )
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// SL-FIX TEST( BayesTree, balanced_smoother_clique_marginals )
|
||||
//{
|
||||
// // Create smoother with 7 nodes
|
||||
// GaussianFactorGraph smoother = createSmoother(7);
|
||||
// Ordering ordering;
|
||||
// ordering += "x1","x3","x5","x7","x2","x6","x4";
|
||||
//
|
||||
// // Create the Bayes tree
|
||||
// GaussianBayesNet chordalBayesNet = smoother.eliminate(ordering);
|
||||
// GaussianISAM bayesTree(chordalBayesNet);
|
||||
//
|
||||
// // Check the clique marginal P(C3)
|
||||
// double sigmax2_alt = 1/1.45533; // THIS NEEDS TO BE CHECKED!
|
||||
// GaussianBayesNet expected = simpleGaussian("x2",zero(2),sigmax2_alt);
|
||||
// push_front(expected,"x1", zero(2), eye(2)*sqrt(2), "x2", -eye(2)*sqrt(2)/2, ones(2));
|
||||
// GaussianISAM::sharedClique R = bayesTree.root(), C3 = bayesTree["x1"];
|
||||
// FactorGraph<GaussianFactor> marginal = C3->marginal<GaussianFactor>(R);
|
||||
// GaussianBayesNet actual = eliminate<GaussianFactor,GaussianConditional>(marginal,C3->keys());
|
||||
// CHECK(assert_equal(expected,actual,tol));
|
||||
//}
|
||||
TEST( BayesTree, balanced_smoother_clique_marginals )
|
||||
{
|
||||
// Create smoother with 7 nodes
|
||||
Ordering ordering;
|
||||
ordering += "x1","x3","x5","x7","x2","x6","x4";
|
||||
GaussianFactorGraph smoother = createSmoother(7, ordering).first;
|
||||
|
||||
// Create the Bayes tree
|
||||
GaussianBayesNet chordalBayesNet = *Inference::Eliminate(smoother);
|
||||
GaussianISAM bayesTree(chordalBayesNet);
|
||||
|
||||
// Check the clique marginal P(C3)
|
||||
double sigmax2_alt = 1/1.45533; // THIS NEEDS TO BE CHECKED!
|
||||
GaussianBayesNet expected = simpleGaussian(ordering["x2"],zero(2),sigmax2_alt);
|
||||
push_front(expected,ordering["x1"], zero(2), eye(2)*sqrt(2), ordering["x2"], -eye(2)*sqrt(2)/2, ones(2));
|
||||
GaussianISAM::sharedClique R = bayesTree.root(), C3 = bayesTree[ordering["x1"]];
|
||||
GaussianFactorGraph marginal = C3->marginal<GaussianFactorGraph>(R);
|
||||
GaussianVariableIndex<> varIndex(marginal);
|
||||
Permutation toFront(Permutation::PullToFront(C3->keys(), varIndex.size()));
|
||||
Permutation toFrontInverse(*toFront.inverse());
|
||||
varIndex.permute(toFront);
|
||||
BOOST_FOREACH(const GaussianFactor::shared_ptr& factor, marginal) {
|
||||
factor->permuteWithInverse(toFrontInverse); }
|
||||
GaussianBayesNet actual = *Inference::EliminateUntil(marginal, C3->keys().size(), varIndex);
|
||||
actual.permuteWithInverse(toFront);
|
||||
CHECK(assert_equal(expected,actual,tol));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// SL-FIX TEST( BayesTree, balanced_smoother_joint )
|
||||
|
|
Loading…
Reference in New Issue