From 7fcd06bb4fba7e594e21d186c01cdf59fe0c42bb Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 17 Sep 2012 14:03:54 +0000 Subject: [PATCH] BayesTree::marginalFactor now calls Clique::marginal2, which in turn calles the new function Clique::separatorMarginal. This calculates marginals with a much simpler recursion, using the parent separator marginal. This could be faster than the shortcut way, especially if separator sizes are small and the root clique is large. The cached marginals have to be discarded when the bayes tree is updated, but this is no different from shortcuts to the root. --- .../discrete/tests/testDiscreteBayesTree.cpp | 30 +++++++++-- gtsam/inference/BayesTree-inl.h | 27 +++++----- gtsam/inference/BayesTree.h | 4 +- gtsam/inference/BayesTreeCliqueBase-inl.h | 52 +++++++++++++++++++ gtsam/inference/BayesTreeCliqueBase.h | 12 +++++ gtsam/inference/GenericSequentialSolver-inl.h | 5 +- 6 files changed, 110 insertions(+), 20 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 68a3f45af..e3a771940 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -129,8 +129,9 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) { // Check whether BN and BT give the same answer on all configurations // Also calculate all some marginals Vector marginals = zero(15); - double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0, - joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0; + double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, + joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, + joint_4_11 = 0; vector allPosbValues = cartesianProduct( key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); @@ -150,6 +151,8 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) { joint_9_12_14 += actual; if (x[8] && x[12] & x[14]) joint_8_12_14 += actual; + if (x[8] && x[12]) + joint_8_12 += actual; if (x[8] && x[2]) joint82 += actual; if (x[1] && x[2]) @@ -165,9 +168,28 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) { } DiscreteFactor::Values all1 = allPosbValues.back(); - // check shortcut P(S9||R) to root Clique::shared_ptr R = bayesTree.root(); - Clique::shared_ptr c = bayesTree[9]; + + // check separator marginal P(S0) + Clique::shared_ptr c = bayesTree[0]; + DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R, + EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); + + // check separator marginal P(S9), should be P(14) + c = bayesTree[9]; + DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R, + EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); + + // check separator marginal of root, should be empty + c = bayesTree[11]; + DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R, + EliminateDiscrete); + EXPECT_LONGS_EQUAL(0, separatorMarginal11.size()); + + // check shortcut P(S9||R) to root + c = bayesTree[9]; DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); EXPECT_LONGS_EQUAL(0, shortcut.size()); diff --git a/gtsam/inference/BayesTree-inl.h b/gtsam/inference/BayesTree-inl.h index 3d5276be8..e09fdc83a 100644 --- a/gtsam/inference/BayesTree-inl.h +++ b/gtsam/inference/BayesTree-inl.h @@ -475,22 +475,23 @@ namespace gtsam { } } - /* ************************************************************************* */ - // First finds clique marginal then marginalizes that - /* ************************************************************************* */ - template - typename CONDITIONAL::FactorType::shared_ptr BayesTree::marginalFactor( - Index j, Eliminate function) const { + /* ************************************************************************* */ + // First finds clique marginal then marginalizes that + /* ************************************************************************* */ + template + typename CONDITIONAL::FactorType::shared_ptr BayesTree::marginalFactor( + Index j, Eliminate function) const { - // get clique containing Index j - sharedClique clique = (*this)[j]; + // get clique containing Index j + sharedClique clique = (*this)[j]; - // calculate or retrieve its marginal - FactorGraph cliqueMarginal = clique->marginal(root_,function); + // calculate or retrieve its marginal P(C) = P(F,S) + FactorGraph cliqueMarginal = clique->marginal2(root_,function); - return GenericSequentialSolver(cliqueMarginal).marginalFactor( - j, function); - } + // now, marginalize out everything that is not variable j + GenericSequentialSolver solver(cliqueMarginal); + return solver.marginalFactor(j, function); + } /* ************************************************************************* */ template diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 075b0b411..54e738559 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -219,7 +219,9 @@ namespace gtsam { void clear(); /** Clear all shortcut caches - use before timing on marginal calculation to avoid residual cache data */ - inline void deleteCachedShorcuts() { root_->deleteCachedShorcuts(); } + inline void deleteCachedShorcuts() { + root_->deleteCachedShorcuts(); + } /** * Remove path from clique to root and return that path as factors diff --git a/gtsam/inference/BayesTreeCliqueBase-inl.h b/gtsam/inference/BayesTreeCliqueBase-inl.h index 7d77a5150..bc98c28b5 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inl.h +++ b/gtsam/inference/BayesTreeCliqueBase-inl.h @@ -212,6 +212,58 @@ namespace gtsam { return *solver.jointFactorGraph(conditional_->keys(), function); } + /* ************************************************************************* */ + // separator marginal, uses separator marginal of parent recursively + // P(C) = P(F|S) P(S) + /* ************************************************************************* */ + template + FactorGraph::FactorType> BayesTreeCliqueBase< + DERIVED, CONDITIONAL>::separatorMarginal(derived_ptr R, Eliminate function) const { + // Check if the Separator marginal was already calculated + if (!cachedSeparatorMarginal_) { + + // If this is the root, there is no separator + if (R.get() == this) { + // we are root, return empty + FactorGraph empty; + cachedSeparatorMarginal_ = empty; + } else { + // Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) + // initialize P(Cp) with the parent separator marginal + derived_ptr parent(parent_.lock()); + FactorGraph p_Cp(parent->separatorMarginal(R, function)); // P(Sp) + // now add the parent cobnditional + p_Cp.push_back(parent->conditional()->toFactor()); // P(Fp|Sp) + + // Create solver that will marginalize for us + GenericSequentialSolver solver(p_Cp); + + // The variables we want to keep are exactly the ones in S + sharedConditional p_F_S = this->conditional(); + std::vector indicesS(p_F_S->beginParents(), p_F_S->endParents()); + + cachedSeparatorMarginal_ = *(solver.jointBayesNet(indicesS, function)); + } + } + + // return the shortcut P(S||B) + return *cachedSeparatorMarginal_; // return the cached version + } + + /* ************************************************************************* */ + // marginal2, uses separator marginal of parent recursively + // P(C) = P(F|S) P(S) + /* ************************************************************************* */ + template + FactorGraph::FactorType> BayesTreeCliqueBase< + DERIVED, CONDITIONAL>::marginal2(derived_ptr R, Eliminate function) const { + // initialize with separator marginal P(S) + FactorGraph p_C(this->separatorMarginal(R, function)); + // add the conditional P(F|S) + p_C.push_back(this->conditional()->toFactor()); + return p_C; + } + /* ************************************************************************* */ // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R) /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index 87711b3f2..edde8dedc 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -83,6 +83,9 @@ namespace gtsam { /// This stores the Cached Shortcut value mutable boost::optional > cachedShortcut_; + /// This stores the Cached separator margnal P(S) + mutable boost::optional > cachedSeparatorMarginal_; + public: sharedConditional conditional_; derived_weak_ptr parent_; @@ -192,6 +195,14 @@ namespace gtsam { FactorGraph marginal(derived_ptr root, Eliminate function) const; + /** return the conditional P(S|Root) on the separator given the root */ + FactorGraph separatorMarginal(derived_ptr root, + Eliminate function) const; + + /** return the marginal P(C) of the clique, using separator shortcuts */ + FactorGraph marginal2(derived_ptr root, + Eliminate function) const; + /** * return the joint P(C1,C2), where C1==this. TODO: not a method? * Limitation: can only calculate joint if cliques are disjoint or one of them is root @@ -233,6 +244,7 @@ namespace gtsam { /// Reset the computed shortcut of this clique. Used by friend BayesTree void resetCachedShortcut() { + cachedSeparatorMarginal_ = boost::none; cachedShortcut_ = boost::none; } diff --git a/gtsam/inference/GenericSequentialSolver-inl.h b/gtsam/inference/GenericSequentialSolver-inl.h index b5511dfd7..7800f4adb 100644 --- a/gtsam/inference/GenericSequentialSolver-inl.h +++ b/gtsam/inference/GenericSequentialSolver-inl.h @@ -200,8 +200,8 @@ namespace gtsam { const std::vector& js, Eliminate function) const { // Eliminate all variables - typename BayesNet::shared_ptr bayesNet = jointBayesNet(js, - function); + typename BayesNet::shared_ptr bayesNet = // + jointBayesNet(js, function); return boost::make_shared >(*bayesNet); } @@ -216,6 +216,7 @@ namespace gtsam { js[0] = j; // Call joint and return the only factor in the factor graph it returns + // TODO: just call jointBayesNet and grab last conditional, then toFactor.... return (*this->jointFactorGraph(js, function))[0]; }