diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 68a3f45af..e3a771940 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -129,8 +129,9 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) { // Check whether BN and BT give the same answer on all configurations // Also calculate all some marginals Vector marginals = zero(15); - double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0, - joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0; + double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, + joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, + joint_4_11 = 0; vector allPosbValues = cartesianProduct( key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); @@ -150,6 +151,8 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) { joint_9_12_14 += actual; if (x[8] && x[12] & x[14]) joint_8_12_14 += actual; + if (x[8] && x[12]) + joint_8_12 += actual; if (x[8] && x[2]) joint82 += actual; if (x[1] && x[2]) @@ -165,9 +168,28 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) { } DiscreteFactor::Values all1 = allPosbValues.back(); - // check shortcut P(S9||R) to root Clique::shared_ptr R = bayesTree.root(); - Clique::shared_ptr c = bayesTree[9]; + + // check separator marginal P(S0) + Clique::shared_ptr c = bayesTree[0]; + DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R, + EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); + + // check separator marginal P(S9), should be P(14) + c = bayesTree[9]; + DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R, + EliminateDiscrete); + EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); + + // check separator marginal of root, should be empty + c = bayesTree[11]; + DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R, + EliminateDiscrete); + EXPECT_LONGS_EQUAL(0, separatorMarginal11.size()); + + // check shortcut P(S9||R) to root + c = bayesTree[9]; DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); EXPECT_LONGS_EQUAL(0, shortcut.size()); diff --git a/gtsam/inference/BayesTree-inl.h b/gtsam/inference/BayesTree-inl.h index 3d5276be8..e09fdc83a 100644 --- a/gtsam/inference/BayesTree-inl.h +++ b/gtsam/inference/BayesTree-inl.h @@ -475,22 +475,23 @@ namespace gtsam { } } - /* ************************************************************************* */ - // First finds clique marginal then marginalizes that - /* ************************************************************************* */ - template - typename CONDITIONAL::FactorType::shared_ptr BayesTree::marginalFactor( - Index j, Eliminate function) const { + /* ************************************************************************* */ + // First finds clique marginal then marginalizes that + /* ************************************************************************* */ + template + typename CONDITIONAL::FactorType::shared_ptr BayesTree::marginalFactor( + Index j, Eliminate function) const { - // get clique containing Index j - sharedClique clique = (*this)[j]; + // get clique containing Index j + sharedClique clique = (*this)[j]; - // calculate or retrieve its marginal - FactorGraph cliqueMarginal = clique->marginal(root_,function); + // calculate or retrieve its marginal P(C) = P(F,S) + FactorGraph cliqueMarginal = clique->marginal2(root_,function); - return GenericSequentialSolver(cliqueMarginal).marginalFactor( - j, function); - } + // now, marginalize out everything that is not variable j + GenericSequentialSolver solver(cliqueMarginal); + return solver.marginalFactor(j, function); + } /* ************************************************************************* */ template diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 075b0b411..54e738559 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -219,7 +219,9 @@ namespace gtsam { void clear(); /** Clear all shortcut caches - use before timing on marginal calculation to avoid residual cache data */ - inline void deleteCachedShorcuts() { root_->deleteCachedShorcuts(); } + inline void deleteCachedShorcuts() { + root_->deleteCachedShorcuts(); + } /** * Remove path from clique to root and return that path as factors diff --git a/gtsam/inference/BayesTreeCliqueBase-inl.h b/gtsam/inference/BayesTreeCliqueBase-inl.h index 7d77a5150..bc98c28b5 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inl.h +++ b/gtsam/inference/BayesTreeCliqueBase-inl.h @@ -212,6 +212,58 @@ namespace gtsam { return *solver.jointFactorGraph(conditional_->keys(), function); } + /* ************************************************************************* */ + // separator marginal, uses separator marginal of parent recursively + // P(C) = P(F|S) P(S) + /* ************************************************************************* */ + template + FactorGraph::FactorType> BayesTreeCliqueBase< + DERIVED, CONDITIONAL>::separatorMarginal(derived_ptr R, Eliminate function) const { + // Check if the Separator marginal was already calculated + if (!cachedSeparatorMarginal_) { + + // If this is the root, there is no separator + if (R.get() == this) { + // we are root, return empty + FactorGraph empty; + cachedSeparatorMarginal_ = empty; + } else { + // Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) + // initialize P(Cp) with the parent separator marginal + derived_ptr parent(parent_.lock()); + FactorGraph p_Cp(parent->separatorMarginal(R, function)); // P(Sp) + // now add the parent cobnditional + p_Cp.push_back(parent->conditional()->toFactor()); // P(Fp|Sp) + + // Create solver that will marginalize for us + GenericSequentialSolver solver(p_Cp); + + // The variables we want to keep are exactly the ones in S + sharedConditional p_F_S = this->conditional(); + std::vector indicesS(p_F_S->beginParents(), p_F_S->endParents()); + + cachedSeparatorMarginal_ = *(solver.jointBayesNet(indicesS, function)); + } + } + + // return the shortcut P(S||B) + return *cachedSeparatorMarginal_; // return the cached version + } + + /* ************************************************************************* */ + // marginal2, uses separator marginal of parent recursively + // P(C) = P(F|S) P(S) + /* ************************************************************************* */ + template + FactorGraph::FactorType> BayesTreeCliqueBase< + DERIVED, CONDITIONAL>::marginal2(derived_ptr R, Eliminate function) const { + // initialize with separator marginal P(S) + FactorGraph p_C(this->separatorMarginal(R, function)); + // add the conditional P(F|S) + p_C.push_back(this->conditional()->toFactor()); + return p_C; + } + /* ************************************************************************* */ // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R) /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index 87711b3f2..edde8dedc 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -83,6 +83,9 @@ namespace gtsam { /// This stores the Cached Shortcut value mutable boost::optional > cachedShortcut_; + /// This stores the Cached separator margnal P(S) + mutable boost::optional > cachedSeparatorMarginal_; + public: sharedConditional conditional_; derived_weak_ptr parent_; @@ -192,6 +195,14 @@ namespace gtsam { FactorGraph marginal(derived_ptr root, Eliminate function) const; + /** return the conditional P(S|Root) on the separator given the root */ + FactorGraph separatorMarginal(derived_ptr root, + Eliminate function) const; + + /** return the marginal P(C) of the clique, using separator shortcuts */ + FactorGraph marginal2(derived_ptr root, + Eliminate function) const; + /** * return the joint P(C1,C2), where C1==this. TODO: not a method? * Limitation: can only calculate joint if cliques are disjoint or one of them is root @@ -233,6 +244,7 @@ namespace gtsam { /// Reset the computed shortcut of this clique. Used by friend BayesTree void resetCachedShortcut() { + cachedSeparatorMarginal_ = boost::none; cachedShortcut_ = boost::none; } diff --git a/gtsam/inference/GenericSequentialSolver-inl.h b/gtsam/inference/GenericSequentialSolver-inl.h index b5511dfd7..7800f4adb 100644 --- a/gtsam/inference/GenericSequentialSolver-inl.h +++ b/gtsam/inference/GenericSequentialSolver-inl.h @@ -200,8 +200,8 @@ namespace gtsam { const std::vector& js, Eliminate function) const { // Eliminate all variables - typename BayesNet::shared_ptr bayesNet = jointBayesNet(js, - function); + typename BayesNet::shared_ptr bayesNet = // + jointBayesNet(js, function); return boost::make_shared >(*bayesNet); } @@ -216,6 +216,7 @@ namespace gtsam { js[0] = j; // Call joint and return the only factor in the factor graph it returns + // TODO: just call jointBayesNet and grab last conditional, then toFactor.... return (*this->jointFactorGraph(js, function))[0]; }