BayesTree::marginalFactor now calls Clique::marginal2, which in turn calles the new function Clique::separatorMarginal. This calculates marginals with a much simpler recursion, using the parent separator marginal. This could be faster than the shortcut way, especially if separator sizes are small and the root clique is large. The cached marginals have to be discarded when the bayes tree is updated, but this is no different from shortcuts to the root.

release/4.3a0
Frank Dellaert 2012-09-17 14:03:54 +00:00
parent 1f0cc0aaa4
commit 7fcd06bb4f
6 changed files with 110 additions and 20 deletions

View File

@ -129,8 +129,9 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
// Check whether BN and BT give the same answer on all configurations
// Also calculate all some marginals
Vector marginals = zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
joint_4_11 = 0;
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
@ -150,6 +151,8 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
joint_9_12_14 += actual;
if (x[8] && x[12] & x[14])
joint_8_12_14 += actual;
if (x[8] && x[12])
joint_8_12 += actual;
if (x[8] && x[2])
joint82 += actual;
if (x[1] && x[2])
@ -165,9 +168,28 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
}
DiscreteFactor::Values all1 = allPosbValues.back();
// check shortcut P(S9||R) to root
Clique::shared_ptr R = bayesTree.root();
Clique::shared_ptr c = bayesTree[9];
// check separator marginal P(S0)
Clique::shared_ptr c = bayesTree[0];
DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R,
EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// check separator marginal P(S9), should be P(14)
c = bayesTree[9];
DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R,
EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
c = bayesTree[11];
DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
EliminateDiscrete);
EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
// check shortcut P(S9||R) to root
c = bayesTree[9];
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_LONGS_EQUAL(0, shortcut.size());

View File

@ -485,11 +485,12 @@ namespace gtsam {
// get clique containing Index j
sharedClique clique = (*this)[j];
// calculate or retrieve its marginal
FactorGraph<FactorType> cliqueMarginal = clique->marginal(root_,function);
// calculate or retrieve its marginal P(C) = P(F,S)
FactorGraph<FactorType> cliqueMarginal = clique->marginal2(root_,function);
return GenericSequentialSolver<FactorType>(cliqueMarginal).marginalFactor(
j, function);
// now, marginalize out everything that is not variable j
GenericSequentialSolver<FactorType> solver(cliqueMarginal);
return solver.marginalFactor(j, function);
}
/* ************************************************************************* */

View File

@ -219,7 +219,9 @@ namespace gtsam {
void clear();
/** Clear all shortcut caches - use before timing on marginal calculation to avoid residual cache data */
inline void deleteCachedShorcuts() { root_->deleteCachedShorcuts(); }
inline void deleteCachedShorcuts() {
root_->deleteCachedShorcuts();
}
/**
* Remove path from clique to root and return that path as factors

View File

@ -212,6 +212,58 @@ namespace gtsam {
return *solver.jointFactorGraph(conditional_->keys(), function);
}
/* ************************************************************************* */
// separator marginal, uses separator marginal of parent recursively
// P(C) = P(F|S) P(S)
/* ************************************************************************* */
template<class DERIVED, class CONDITIONAL>
FactorGraph<typename BayesTreeCliqueBase<DERIVED, CONDITIONAL>::FactorType> BayesTreeCliqueBase<
DERIVED, CONDITIONAL>::separatorMarginal(derived_ptr R, Eliminate function) const {
// Check if the Separator marginal was already calculated
if (!cachedSeparatorMarginal_) {
// If this is the root, there is no separator
if (R.get() == this) {
// we are root, return empty
FactorGraph<FactorType> empty;
cachedSeparatorMarginal_ = empty;
} else {
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
// initialize P(Cp) with the parent separator marginal
derived_ptr parent(parent_.lock());
FactorGraph<FactorType> p_Cp(parent->separatorMarginal(R, function)); // P(Sp)
// now add the parent cobnditional
p_Cp.push_back(parent->conditional()->toFactor()); // P(Fp|Sp)
// Create solver that will marginalize for us
GenericSequentialSolver<FactorType> solver(p_Cp);
// The variables we want to keep are exactly the ones in S
sharedConditional p_F_S = this->conditional();
std::vector<Index> indicesS(p_F_S->beginParents(), p_F_S->endParents());
cachedSeparatorMarginal_ = *(solver.jointBayesNet(indicesS, function));
}
}
// return the shortcut P(S||B)
return *cachedSeparatorMarginal_; // return the cached version
}
/* ************************************************************************* */
// marginal2, uses separator marginal of parent recursively
// P(C) = P(F|S) P(S)
/* ************************************************************************* */
template<class DERIVED, class CONDITIONAL>
FactorGraph<typename BayesTreeCliqueBase<DERIVED, CONDITIONAL>::FactorType> BayesTreeCliqueBase<
DERIVED, CONDITIONAL>::marginal2(derived_ptr R, Eliminate function) const {
// initialize with separator marginal P(S)
FactorGraph<FactorType> p_C(this->separatorMarginal(R, function));
// add the conditional P(F|S)
p_C.push_back(this->conditional()->toFactor());
return p_C;
}
/* ************************************************************************* */
// P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R)
/* ************************************************************************* */

View File

@ -83,6 +83,9 @@ namespace gtsam {
/// This stores the Cached Shortcut value
mutable boost::optional<BayesNet<ConditionalType> > cachedShortcut_;
/// This stores the Cached separator margnal P(S)
mutable boost::optional<FactorGraph<FactorType> > cachedSeparatorMarginal_;
public:
sharedConditional conditional_;
derived_weak_ptr parent_;
@ -192,6 +195,14 @@ namespace gtsam {
FactorGraph<FactorType> marginal(derived_ptr root,
Eliminate function) const;
/** return the conditional P(S|Root) on the separator given the root */
FactorGraph<FactorType> separatorMarginal(derived_ptr root,
Eliminate function) const;
/** return the marginal P(C) of the clique, using separator shortcuts */
FactorGraph<FactorType> marginal2(derived_ptr root,
Eliminate function) const;
/**
* return the joint P(C1,C2), where C1==this. TODO: not a method?
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
@ -233,6 +244,7 @@ namespace gtsam {
/// Reset the computed shortcut of this clique. Used by friend BayesTree
void resetCachedShortcut() {
cachedSeparatorMarginal_ = boost::none;
cachedShortcut_ = boost::none;
}

View File

@ -200,8 +200,8 @@ namespace gtsam {
const std::vector<Index>& js, Eliminate function) const {
// Eliminate all variables
typename BayesNet<Conditional>::shared_ptr bayesNet = jointBayesNet(js,
function);
typename BayesNet<Conditional>::shared_ptr bayesNet = //
jointBayesNet(js, function);
return boost::make_shared<FactorGraph<FACTOR> >(*bayesNet);
}
@ -216,6 +216,7 @@ namespace gtsam {
js[0] = j;
// Call joint and return the only factor in the factor graph it returns
// TODO: just call jointBayesNet and grab last conditional, then toFactor....
return (*this->jointFactorGraph(js, function))[0];
}