diff --git a/inference/BayesTree-inl.h b/inference/BayesTree-inl.h index 22c12e00b..7198662cf 100644 --- a/inference/BayesTree-inl.h +++ b/inference/BayesTree-inl.h @@ -329,27 +329,27 @@ namespace gtsam { return p_S_R; } -// /* ************************************************************************* */ -// // P(C) = \int_R P(F|S) P(S|R) P(R) -// // TODO: Maybe we should integrate given parent marginal P(Cp), -// // \int(Cp\S) P(F|S)P(S|Cp)P(Cp) -// // Because the root clique could be very big. -// /* ************************************************************************* */ -// template -// template -// FactorGraph -// BayesTree::Clique::marginal(shared_ptr R) { -// // If we are the root, just return this root -// if (R.get()==this) return *R; -// -// // Combine P(F|S), P(S|R), and P(R) -// BayesNet p_FSR = this->shortcut(R); -// p_FSR.push_front(*this); -// p_FSR.push_back(*R); -// -// // Find marginal on the keys we are interested in -// return marginalize(p_FSR,keys()); -// } + /* ************************************************************************* */ + // P(C) = \int_R P(F|S) P(S|R) P(R) + // TODO: Maybe we should integrate given parent marginal P(Cp), + // \int(Cp\S) P(F|S)P(S|Cp)P(Cp) + // Because the root clique could be very big. + /* ************************************************************************* */ + template + template + FactorGraph + BayesTree::Clique::marginal(shared_ptr R) { + // If we are the root, just return this root + if (R.get()==this) return *R; + + // Combine P(F|S), P(S|R), and P(R) + BayesNet p_FSR = this->shortcut(R); + p_FSR.push_front(*this); + p_FSR.push_back(*R); + + // Find marginal on the keys we are interested in + return FactorGraph(*Inference::Marginal(FactorGraph(p_FSR), keys())); + } // /* ************************************************************************* */ // // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R) @@ -676,41 +676,51 @@ namespace gtsam { } } -// /* ************************************************************************* */ -// // First finds clique marginal then marginalizes that -// /* ************************************************************************* */ -// template -// template -// FactorGraph -// BayesTree::marginal(varid_t key) const { -// -// // get clique containing key -// sharedClique clique = (*this)[key]; -// -// // calculate or retrieve its marginal -// FactorGraph cliqueMarginal = clique->marginal(root_); -// -// // create an ordering where only the requested key is not eliminated -// vector ord = clique->keys(); -// ord.remove(key); -// -// // partially eliminate, remaining factor graph is requested marginal -// eliminate(cliqueMarginal,ord); -// return cliqueMarginal; -// } + /* ************************************************************************* */ + // First finds clique marginal then marginalizes that + /* ************************************************************************* */ + template + template + FactorGraph + BayesTree::marginal(varid_t key) const { -// /* ************************************************************************* */ -// template -// template -// BayesNet -// BayesTree::marginalBayesNet(varid_t key) const { -// -// // calculate marginal as a factor graph -// FactorGraph fg = this->marginal(key); -// -// // eliminate further to Bayes net -// return eliminate(fg,Ordering(key)); -// } + // get clique containing key + sharedClique clique = (*this)[key]; + + // calculate or retrieve its marginal + FactorGraph cliqueMarginal = clique->marginal(root_); + + // Reorder so that only the requested key is not eliminated + typename FactorGraph::variableindex_type varIndex(cliqueMarginal); + vector keyAsVector(1); keyAsVector[0] = key; + Permutation toBack(Permutation::PushToBack(keyAsVector, varIndex.size())); + Permutation::shared_ptr toBackInverse(toBack.inverse()); + varIndex.permute(toBack); + BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, cliqueMarginal) { + factor->permuteWithInverse(*toBackInverse); + } + + // partially eliminate, remaining factor graph is requested marginal + Inference::EliminateUntil(cliqueMarginal, varIndex.size()-1, varIndex); + BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, cliqueMarginal) { + if(factor) + factor->permuteWithInverse(toBack); + } + return cliqueMarginal; + } + + /* ************************************************************************* */ + template + template + BayesNet + BayesTree::marginalBayesNet(varid_t key) const { + + // calculate marginal as a factor graph + FactorGraph fg = this->marginal(key); + + // eliminate further to Bayes net + return *Inference::Eliminate(fg); + } // /* ************************************************************************* */ // // Find two cliques, their joint, then marginalizes diff --git a/inference/BayesTree.h b/inference/BayesTree.h index 707641164..6523dd14d 100644 --- a/inference/BayesTree.h +++ b/inference/BayesTree.h @@ -99,10 +99,10 @@ namespace gtsam { template BayesNet shortcut(shared_ptr root); -// /** return the marginal P(C) of the clique */ -// template -// FactorGraph marginal(shared_ptr root); -// + /** return the marginal P(C) of the clique */ + template + FactorGraph marginal(shared_ptr root); + // /** return the joint P(C1,C2), where C1==this. TODO: not a method? */ // template // std::pair,Ordering> joint(shared_ptr C2, shared_ptr root); @@ -245,14 +245,14 @@ namespace gtsam { /** Gather data on all cliques */ CliqueData getCliqueData() const; -// /** return marginal on any variable */ -// template -// FactorGraph marginal(varid_t key) const; -// -// /** return marginal on any variable, as a Bayes Net */ -// template -// BayesNet marginalBayesNet(varid_t key) const; -// + /** return marginal on any variable */ + template + FactorGraph marginal(varid_t key) const; + + /** return marginal on any variable, as a Bayes Net */ + template + BayesNet marginalBayesNet(varid_t key) const; + // /** return joint on two variables */ // template // FactorGraph joint(varid_t key1, varid_t key2) const; diff --git a/inference/Permutation.cpp b/inference/Permutation.cpp index b88a3fe63..af3b9eb30 100644 --- a/inference/Permutation.cpp +++ b/inference/Permutation.cpp @@ -30,12 +30,14 @@ Permutation Permutation::PullToFront(const vector& toFront, size_t size // Mask of which variables have been pulled, used to reorder vector pulled(size, false); + // Put the pulled variables at the front of the permutation and set up the + // pulled flags. for(varid_t j=0; j& toFront, size_t size return ret; } +/* ************************************************************************* */ +Permutation Permutation::PushToBack(const std::vector& toBack, size_t size) { + + Permutation ret(size); + + // Mask of which variables have been pushed, used to reorder + vector pushed(size, false); + + // Put the pushed variables at the back of the permutation and set up the + // pushed flags; + varid_t nextVar = size - toBack.size(); + for(varid_t j=0; j& toFront, size_t size); + /** + * Create a permutation that pulls the given variables to the front while + * pushing the rest to the back. + */ + static Permutation PushToBack(const std::vector& toBack, size_t size); + iterator begin() { return rangeIndices_.begin(); } const_iterator begin() const { return rangeIndices_.begin(); } iterator end() { return rangeIndices_.end(); } diff --git a/linear/VectorValues.h b/linear/VectorValues.h index 0b7725340..f6d2130ef 100644 --- a/linear/VectorValues.h +++ b/linear/VectorValues.h @@ -51,6 +51,9 @@ public: template VectorValues(const Container& dimensions); + /** Construct to hold nVars vectors of varDim dimension each. */ + VectorValues(varid_t nVars, size_t varDim); + /** Construct from a container of variable dimensions in variable order and * a combined Vector of all of the variables in order. */ @@ -179,6 +182,14 @@ inline VectorValues::VectorValues(const Container& dimensions) : varStarts_(dime values_.resize(varStarts_.back(), false); } +inline VectorValues::VectorValues(varid_t nVars, size_t varDim) : varStarts_(nVars+1) { + varStarts_[0] = 0; + size_t varStart = 0; + for(varid_t j=1; j<=nVars; ++j) + varStarts_[j] = (varStart += varDim); + values_.resize(varStarts_.back(), false); +} + inline VectorValues::VectorValues(const std::vector& dimensions, const Vector& values) : values_(values), varStarts_(dimensions.size()+1) { varStarts_[0] = 0; diff --git a/tests/testGaussianISAM.cpp b/tests/testGaussianISAM.cpp index db96dc507..0f6a083b7 100644 --- a/tests/testGaussianISAM.cpp +++ b/tests/testGaussianISAM.cpp @@ -157,54 +157,53 @@ TEST( BayesTree, linear_smoother_shortcuts ) C4 x7 : x6 ************************************************************************* */ -// SL-FIX TEST( BayesTree, balanced_smoother_marginals ) -//{ -// // Create smoother with 7 nodes -// GaussianFactorGraph smoother = createSmoother(7); -// Ordering ordering; -// ordering += "x1","x3","x5","x7","x2","x6","x4"; -// -// // eliminate using a "nested dissection" ordering -// GaussianBayesNet chordalBayesNet = smoother.eliminate(ordering); -// -// VectorValues expectedSolution; -// BOOST_FOREACH(string key, ordering) -// expectedSolution.insert(key,zero(2)); -// VectorValues actualSolution = optimize(chordalBayesNet); -// CHECK(assert_equal(expectedSolution,actualSolution,tol)); -// -// // Create the Bayes tree -// GaussianISAM bayesTree(chordalBayesNet); -// LONGS_EQUAL(4,bayesTree.size()); -// -// double tol=1e-5; -// -// // Check marginal on x1 -// GaussianBayesNet expected1 = simpleGaussian("x1", zero(2), sigmax1); -// GaussianBayesNet actual1 = bayesTree.marginalBayesNet("x1"); -// CHECK(assert_equal(expected1,actual1,tol)); -// -// // Check marginal on x2 -// double sigx2 = 0.68712938; // FIXME: this should be corrected analytically -// GaussianBayesNet expected2 = simpleGaussian("x2", zero(2), sigx2); -// GaussianBayesNet actual2 = bayesTree.marginalBayesNet("x2"); -// CHECK(assert_equal(expected2,actual2,tol)); // FAILS -// -// // Check marginal on x3 -// GaussianBayesNet expected3 = simpleGaussian("x3", zero(2), sigmax3); -// GaussianBayesNet actual3 = bayesTree.marginalBayesNet("x3"); -// CHECK(assert_equal(expected3,actual3,tol)); -// -// // Check marginal on x4 -// GaussianBayesNet expected4 = simpleGaussian("x4", zero(2), sigmax4); -// GaussianBayesNet actual4 = bayesTree.marginalBayesNet("x4"); -// CHECK(assert_equal(expected4,actual4,tol)); -// -// // Check marginal on x7 (should be equal to x1) -// GaussianBayesNet expected7 = simpleGaussian("x7", zero(2), sigmax7); -// GaussianBayesNet actual7 = bayesTree.marginalBayesNet("x7"); -// CHECK(assert_equal(expected7,actual7,tol)); -//} +TEST( BayesTree, balanced_smoother_marginals ) +{ + // Create smoother with 7 nodes + Ordering ordering; + ordering += "x1","x3","x5","x7","x2","x6","x4"; + GaussianFactorGraph smoother = createSmoother(7, ordering).first; + + // Create the Bayes tree + GaussianBayesNet chordalBayesNet = *Inference::Eliminate(smoother); + + VectorValues expectedSolution(7, 2); + expectedSolution.makeZero(); + VectorValues actualSolution = optimize(chordalBayesNet); + CHECK(assert_equal(expectedSolution,actualSolution,tol)); + + // Create the Bayes tree + GaussianISAM bayesTree(chordalBayesNet); + LONGS_EQUAL(4,bayesTree.size()); + + double tol=1e-5; + + // Check marginal on x1 + GaussianBayesNet expected1 = simpleGaussian(ordering["x1"], zero(2), sigmax1); + GaussianBayesNet actual1 = bayesTree.marginalBayesNet(ordering["x1"]); + CHECK(assert_equal(expected1,actual1,tol)); + + // Check marginal on x2 + double sigx2 = 0.68712938; // FIXME: this should be corrected analytically + GaussianBayesNet expected2 = simpleGaussian(ordering["x2"], zero(2), sigx2); + GaussianBayesNet actual2 = bayesTree.marginalBayesNet(ordering["x2"]); + CHECK(assert_equal(expected2,actual2,tol)); // FAILS + + // Check marginal on x3 + GaussianBayesNet expected3 = simpleGaussian(ordering["x3"], zero(2), sigmax3); + GaussianBayesNet actual3 = bayesTree.marginalBayesNet(ordering["x3"]); + CHECK(assert_equal(expected3,actual3,tol)); + + // Check marginal on x4 + GaussianBayesNet expected4 = simpleGaussian(ordering["x4"], zero(2), sigmax4); + GaussianBayesNet actual4 = bayesTree.marginalBayesNet(ordering["x4"]); + CHECK(assert_equal(expected4,actual4,tol)); + + // Check marginal on x7 (should be equal to x1) + GaussianBayesNet expected7 = simpleGaussian(ordering["x7"], zero(2), sigmax7); + GaussianBayesNet actual7 = bayesTree.marginalBayesNet(ordering["x7"]); + CHECK(assert_equal(expected7,actual7,tol)); +} /* ************************************************************************* */ TEST( BayesTree, balanced_smoother_shortcuts ) @@ -238,26 +237,33 @@ TEST( BayesTree, balanced_smoother_shortcuts ) } /* ************************************************************************* */ -// SL-FIX TEST( BayesTree, balanced_smoother_clique_marginals ) -//{ -// // Create smoother with 7 nodes -// GaussianFactorGraph smoother = createSmoother(7); -// Ordering ordering; -// ordering += "x1","x3","x5","x7","x2","x6","x4"; -// -// // Create the Bayes tree -// GaussianBayesNet chordalBayesNet = smoother.eliminate(ordering); -// GaussianISAM bayesTree(chordalBayesNet); -// -// // Check the clique marginal P(C3) -// double sigmax2_alt = 1/1.45533; // THIS NEEDS TO BE CHECKED! -// GaussianBayesNet expected = simpleGaussian("x2",zero(2),sigmax2_alt); -// push_front(expected,"x1", zero(2), eye(2)*sqrt(2), "x2", -eye(2)*sqrt(2)/2, ones(2)); -// GaussianISAM::sharedClique R = bayesTree.root(), C3 = bayesTree["x1"]; -// FactorGraph marginal = C3->marginal(R); -// GaussianBayesNet actual = eliminate(marginal,C3->keys()); -// CHECK(assert_equal(expected,actual,tol)); -//} +TEST( BayesTree, balanced_smoother_clique_marginals ) +{ + // Create smoother with 7 nodes + Ordering ordering; + ordering += "x1","x3","x5","x7","x2","x6","x4"; + GaussianFactorGraph smoother = createSmoother(7, ordering).first; + + // Create the Bayes tree + GaussianBayesNet chordalBayesNet = *Inference::Eliminate(smoother); + GaussianISAM bayesTree(chordalBayesNet); + + // Check the clique marginal P(C3) + double sigmax2_alt = 1/1.45533; // THIS NEEDS TO BE CHECKED! + GaussianBayesNet expected = simpleGaussian(ordering["x2"],zero(2),sigmax2_alt); + push_front(expected,ordering["x1"], zero(2), eye(2)*sqrt(2), ordering["x2"], -eye(2)*sqrt(2)/2, ones(2)); + GaussianISAM::sharedClique R = bayesTree.root(), C3 = bayesTree[ordering["x1"]]; + GaussianFactorGraph marginal = C3->marginal(R); + GaussianVariableIndex<> varIndex(marginal); + Permutation toFront(Permutation::PullToFront(C3->keys(), varIndex.size())); + Permutation toFrontInverse(*toFront.inverse()); + varIndex.permute(toFront); + BOOST_FOREACH(const GaussianFactor::shared_ptr& factor, marginal) { + factor->permuteWithInverse(toFrontInverse); } + GaussianBayesNet actual = *Inference::EliminateUntil(marginal, C3->keys().size(), varIndex); + actual.permuteWithInverse(toFront); + CHECK(assert_equal(expected,actual,tol)); +} /* ************************************************************************* */ // SL-FIX TEST( BayesTree, balanced_smoother_joint )