Added marginal function to GaussianISAM, renamed and added comments to bayes tree
parent
9f4661544f
commit
d8f05f78ff
|
@ -382,6 +382,7 @@ namespace gtsam {
|
||||||
template<class CONDITIONAL>
|
template<class CONDITIONAL>
|
||||||
FactorGraph<typename CONDITIONAL::Factor> BayesTree<CONDITIONAL>::Clique::marginal(shared_ptr R) {
|
FactorGraph<typename CONDITIONAL::Factor> BayesTree<CONDITIONAL>::Clique::marginal(shared_ptr R) {
|
||||||
// If we are the root, just return this root
|
// If we are the root, just return this root
|
||||||
|
// NOTE: immediately cast to a factor graph
|
||||||
if (R.get()==this) return *R;
|
if (R.get()==this) return *R;
|
||||||
|
|
||||||
// Combine P(F|S), P(S|R), and P(R)
|
// Combine P(F|S), P(S|R), and P(R)
|
||||||
|
@ -389,9 +390,6 @@ namespace gtsam {
|
||||||
p_FSR.push_front(*this);
|
p_FSR.push_front(*this);
|
||||||
p_FSR.push_back(*R);
|
p_FSR.push_back(*R);
|
||||||
|
|
||||||
// Find marginal on the keys we are interested in
|
|
||||||
FactorGraph<typename CONDITIONAL::Factor> p_FSRfg(p_FSR);
|
|
||||||
|
|
||||||
assertInvariants();
|
assertInvariants();
|
||||||
return *GenericSequentialSolver<typename CONDITIONAL::Factor>(p_FSR).jointFactorGraph(keys());
|
return *GenericSequentialSolver<typename CONDITIONAL::Factor>(p_FSR).jointFactorGraph(keys());
|
||||||
}
|
}
|
||||||
|
@ -731,7 +729,7 @@ namespace gtsam {
|
||||||
// First finds clique marginal then marginalizes that
|
// First finds clique marginal then marginalizes that
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class CONDITIONAL>
|
template<class CONDITIONAL>
|
||||||
typename CONDITIONAL::Factor::shared_ptr BayesTree<CONDITIONAL>::marginal(Index key) const {
|
typename CONDITIONAL::Factor::shared_ptr BayesTree<CONDITIONAL>::marginalFactor(Index key) const {
|
||||||
|
|
||||||
// get clique containing key
|
// get clique containing key
|
||||||
sharedClique clique = (*this)[key];
|
sharedClique clique = (*this)[key];
|
||||||
|
@ -748,7 +746,7 @@ namespace gtsam {
|
||||||
|
|
||||||
// calculate marginal as a factor graph
|
// calculate marginal as a factor graph
|
||||||
FactorGraph<typename CONDITIONAL::Factor> fg;
|
FactorGraph<typename CONDITIONAL::Factor> fg;
|
||||||
fg.push_back(this->marginal(key));
|
fg.push_back(this->marginalFactor(key));
|
||||||
|
|
||||||
// eliminate factor graph marginal to a Bayes net
|
// eliminate factor graph marginal to a Bayes net
|
||||||
return GenericSequentialSolver<typename CONDITIONAL::Factor>(fg).eliminate();
|
return GenericSequentialSolver<typename CONDITIONAL::Factor>(fg).eliminate();
|
||||||
|
|
|
@ -262,9 +262,13 @@ namespace gtsam {
|
||||||
CliqueData getCliqueData() const;
|
CliqueData getCliqueData() const;
|
||||||
|
|
||||||
/** return marginal on any variable */
|
/** return marginal on any variable */
|
||||||
typename CONDITIONAL::Factor::shared_ptr marginal(Index key) const;
|
typename CONDITIONAL::Factor::shared_ptr marginalFactor(Index key) const;
|
||||||
|
|
||||||
/** return marginal on any variable, as a Bayes Net */
|
/**
|
||||||
|
* return marginal on any variable, as a Bayes Net
|
||||||
|
* NOTE: this function calls marginal, and then eliminates it into a Bayes Net
|
||||||
|
* This is more expensive than the above function
|
||||||
|
*/
|
||||||
typename BayesNet<CONDITIONAL>::shared_ptr marginalBayesNet(Index key) const;
|
typename BayesNet<CONDITIONAL>::shared_ptr marginalBayesNet(Index key) const;
|
||||||
|
|
||||||
/** return joint on two variables */
|
/** return joint on two variables */
|
||||||
|
|
|
@ -75,7 +75,7 @@ GenericMultifrontalSolver<FACTOR, JUNCTIONTREE>::jointFactorGraph(const std::vec
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template<class FACTOR, class JUNCTIONTREE>
|
template<class FACTOR, class JUNCTIONTREE>
|
||||||
typename FACTOR::shared_ptr GenericMultifrontalSolver<FACTOR, JUNCTIONTREE>::marginalFactor(Index j) const {
|
typename FACTOR::shared_ptr GenericMultifrontalSolver<FACTOR, JUNCTIONTREE>::marginalFactor(Index j) const {
|
||||||
return eliminate()->marginal(j);
|
return eliminate()->marginalFactor(j);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,8 +24,20 @@ using namespace gtsam;
|
||||||
#include <gtsam/inference/ISAM-inl.h>
|
#include <gtsam/inference/ISAM-inl.h>
|
||||||
template class ISAM<GaussianConditional>;
|
template class ISAM<GaussianConditional>;
|
||||||
|
|
||||||
|
namespace ublas = boost::numeric::ublas;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
std::pair<Vector, Matrix> GaussianISAM::marginal(Index j) const {
|
||||||
|
GaussianFactor::shared_ptr factor = this->marginalFactor(j);
|
||||||
|
FactorGraph<GaussianFactor> graph;
|
||||||
|
graph.push_back(factor);
|
||||||
|
GaussianConditional::shared_ptr conditional = GaussianFactor::CombineAndEliminate(graph,1).first->front();
|
||||||
|
Matrix R = conditional->get_R();
|
||||||
|
return make_pair(conditional->get_d(), inverse(ublas::prod(ublas::trans(R), R)));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void optimize(const GaussianISAM::sharedClique& clique, VectorValues& result) {
|
void optimize(const GaussianISAM::sharedClique& clique, VectorValues& result) {
|
||||||
// parents are assumed to already be solved and available in result
|
// parents are assumed to already be solved and available in result
|
||||||
|
|
|
@ -69,6 +69,9 @@ public:
|
||||||
|
|
||||||
friend VectorValues optimize(const GaussianISAM&);
|
friend VectorValues optimize(const GaussianISAM&);
|
||||||
|
|
||||||
|
/** return marginal on any variable */
|
||||||
|
std::pair<Vector,Matrix> marginal(Index key) const;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// recursively optimize this conditional and all subtrees
|
// recursively optimize this conditional and all subtrees
|
||||||
|
|
|
@ -195,28 +195,58 @@ TEST( BayesTree, balanced_smoother_marginals )
|
||||||
// Check marginal on x1
|
// Check marginal on x1
|
||||||
GaussianBayesNet expected1 = simpleGaussian(ordering["x1"], zero(2), sigmax1);
|
GaussianBayesNet expected1 = simpleGaussian(ordering["x1"], zero(2), sigmax1);
|
||||||
GaussianBayesNet actual1 = *bayesTree.marginalBayesNet(ordering["x1"]);
|
GaussianBayesNet actual1 = *bayesTree.marginalBayesNet(ordering["x1"]);
|
||||||
CHECK(assert_equal(expected1,actual1,tol));
|
Matrix expectedCovarianceX1 = eye(2,2) * (sigmax1 * sigmax1);
|
||||||
|
Vector expectedMeanX1 = zero(2);
|
||||||
|
Matrix actualCovarianceX1; Vector actualMeanX1;
|
||||||
|
boost::tie(actualMeanX1, actualCovarianceX1) = bayesTree.marginal(ordering["x1"]);
|
||||||
|
EXPECT(assert_equal(expectedCovarianceX1, actualCovarianceX1, tol));
|
||||||
|
EXPECT(assert_equal(expectedMeanX1, actualMeanX1, tol));
|
||||||
|
EXPECT(assert_equal(expected1,actual1,tol));
|
||||||
|
|
||||||
// Check marginal on x2
|
// Check marginal on x2
|
||||||
double sigx2 = 0.68712938; // FIXME: this should be corrected analytically
|
double sigx2 = 0.68712938; // FIXME: this should be corrected analytically
|
||||||
GaussianBayesNet expected2 = simpleGaussian(ordering["x2"], zero(2), sigx2);
|
GaussianBayesNet expected2 = simpleGaussian(ordering["x2"], zero(2), sigx2);
|
||||||
GaussianBayesNet actual2 = *bayesTree.marginalBayesNet(ordering["x2"]);
|
GaussianBayesNet actual2 = *bayesTree.marginalBayesNet(ordering["x2"]);
|
||||||
CHECK(assert_equal(expected2,actual2,tol)); // FAILS
|
Matrix expectedCovarianceX2 = eye(2,2) * (sigx2 * sigx2);
|
||||||
|
Vector expectedMeanX2 = zero(2);
|
||||||
|
Matrix actualCovarianceX2; Vector actualMeanX2;
|
||||||
|
boost::tie(actualMeanX2, actualCovarianceX2) = bayesTree.marginal(ordering["x2"]);
|
||||||
|
EXPECT(assert_equal(expectedCovarianceX2, actualCovarianceX2, tol));
|
||||||
|
EXPECT(assert_equal(expectedMeanX2, actualMeanX2, tol));
|
||||||
|
EXPECT(assert_equal(expected2,actual2,tol));
|
||||||
|
|
||||||
// Check marginal on x3
|
// Check marginal on x3
|
||||||
GaussianBayesNet expected3 = simpleGaussian(ordering["x3"], zero(2), sigmax3);
|
GaussianBayesNet expected3 = simpleGaussian(ordering["x3"], zero(2), sigmax3);
|
||||||
GaussianBayesNet actual3 = *bayesTree.marginalBayesNet(ordering["x3"]);
|
GaussianBayesNet actual3 = *bayesTree.marginalBayesNet(ordering["x3"]);
|
||||||
CHECK(assert_equal(expected3,actual3,tol));
|
Matrix expectedCovarianceX3 = eye(2,2) * (sigmax3 * sigmax3);
|
||||||
|
Vector expectedMeanX3 = zero(2);
|
||||||
|
Matrix actualCovarianceX3; Vector actualMeanX3;
|
||||||
|
boost::tie(actualMeanX3, actualCovarianceX3) = bayesTree.marginal(ordering["x3"]);
|
||||||
|
EXPECT(assert_equal(expectedCovarianceX3, actualCovarianceX3, tol));
|
||||||
|
EXPECT(assert_equal(expectedMeanX3, actualMeanX3, tol));
|
||||||
|
EXPECT(assert_equal(expected3,actual3,tol));
|
||||||
|
|
||||||
// Check marginal on x4
|
// Check marginal on x4
|
||||||
GaussianBayesNet expected4 = simpleGaussian(ordering["x4"], zero(2), sigmax4);
|
GaussianBayesNet expected4 = simpleGaussian(ordering["x4"], zero(2), sigmax4);
|
||||||
GaussianBayesNet actual4 = *bayesTree.marginalBayesNet(ordering["x4"]);
|
GaussianBayesNet actual4 = *bayesTree.marginalBayesNet(ordering["x4"]);
|
||||||
CHECK(assert_equal(expected4,actual4,tol));
|
Matrix expectedCovarianceX4 = eye(2,2) * (sigmax4 * sigmax4);
|
||||||
|
Vector expectedMeanX4 = zero(2);
|
||||||
|
Matrix actualCovarianceX4; Vector actualMeanX4;
|
||||||
|
boost::tie(actualMeanX4, actualCovarianceX4) = bayesTree.marginal(ordering["x4"]);
|
||||||
|
EXPECT(assert_equal(expectedCovarianceX4, actualCovarianceX4, tol));
|
||||||
|
EXPECT(assert_equal(expectedMeanX4, actualMeanX4, tol));
|
||||||
|
EXPECT(assert_equal(expected4,actual4,tol));
|
||||||
|
|
||||||
// Check marginal on x7 (should be equal to x1)
|
// Check marginal on x7 (should be equal to x1)
|
||||||
GaussianBayesNet expected7 = simpleGaussian(ordering["x7"], zero(2), sigmax7);
|
GaussianBayesNet expected7 = simpleGaussian(ordering["x7"], zero(2), sigmax7);
|
||||||
GaussianBayesNet actual7 = *bayesTree.marginalBayesNet(ordering["x7"]);
|
GaussianBayesNet actual7 = *bayesTree.marginalBayesNet(ordering["x7"]);
|
||||||
CHECK(assert_equal(expected7,actual7,tol));
|
Matrix expectedCovarianceX7 = eye(2,2) * (sigmax7 * sigmax7);
|
||||||
|
Vector expectedMeanX7 = zero(2);
|
||||||
|
Matrix actualCovarianceX7; Vector actualMeanX7;
|
||||||
|
boost::tie(actualMeanX7, actualCovarianceX7) = bayesTree.marginal(ordering["x7"]);
|
||||||
|
EXPECT(assert_equal(expectedCovarianceX7, actualCovarianceX7, tol));
|
||||||
|
EXPECT(assert_equal(expectedMeanX7, actualMeanX7, tol));
|
||||||
|
EXPECT(assert_equal(expected7,actual7,tol));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue