BayesTree::marginalFactor now calls Clique::marginal2, which in turn calles the new function Clique::separatorMarginal. This calculates marginals with a much simpler recursion, using the parent separator marginal. This could be faster than the shortcut way, especially if separator sizes are small and the root clique is large. The cached marginals have to be discarded when the bayes tree is updated, but this is no different from shortcuts to the root.
parent
1f0cc0aaa4
commit
7fcd06bb4f
|
@ -129,8 +129,9 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
|||
// Check whether BN and BT give the same answer on all configurations
|
||||
// Also calculate all some marginals
|
||||
Vector marginals = zero(15);
|
||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint82 = 0,
|
||||
joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, joint_4_11 = 0;
|
||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
||||
joint_4_11 = 0;
|
||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
||||
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
|
||||
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||
|
@ -150,6 +151,8 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
|||
joint_9_12_14 += actual;
|
||||
if (x[8] && x[12] & x[14])
|
||||
joint_8_12_14 += actual;
|
||||
if (x[8] && x[12])
|
||||
joint_8_12 += actual;
|
||||
if (x[8] && x[2])
|
||||
joint82 += actual;
|
||||
if (x[1] && x[2])
|
||||
|
@ -165,9 +168,28 @@ TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
|||
}
|
||||
DiscreteFactor::Values all1 = allPosbValues.back();
|
||||
|
||||
// check shortcut P(S9||R) to root
|
||||
Clique::shared_ptr R = bayesTree.root();
|
||||
Clique::shared_ptr c = bayesTree[9];
|
||||
|
||||
// check separator marginal P(S0)
|
||||
Clique::shared_ptr c = bayesTree[0];
|
||||
DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R,
|
||||
EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||
|
||||
// check separator marginal P(S9), should be P(14)
|
||||
c = bayesTree[9];
|
||||
DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R,
|
||||
EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||
|
||||
// check separator marginal of root, should be empty
|
||||
c = bayesTree[11];
|
||||
DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
|
||||
EliminateDiscrete);
|
||||
EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
|
||||
|
||||
// check shortcut P(S9||R) to root
|
||||
c = bayesTree[9];
|
||||
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
EXPECT_LONGS_EQUAL(0, shortcut.size());
|
||||
|
||||
|
|
|
@ -475,22 +475,23 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// First finds clique marginal then marginalizes that
|
||||
/* ************************************************************************* */
|
||||
template<class CONDITIONAL, class CLIQUE>
|
||||
typename CONDITIONAL::FactorType::shared_ptr BayesTree<CONDITIONAL,CLIQUE>::marginalFactor(
|
||||
Index j, Eliminate function) const {
|
||||
/* ************************************************************************* */
|
||||
// First finds clique marginal then marginalizes that
|
||||
/* ************************************************************************* */
|
||||
template<class CONDITIONAL, class CLIQUE>
|
||||
typename CONDITIONAL::FactorType::shared_ptr BayesTree<CONDITIONAL,CLIQUE>::marginalFactor(
|
||||
Index j, Eliminate function) const {
|
||||
|
||||
// get clique containing Index j
|
||||
sharedClique clique = (*this)[j];
|
||||
// get clique containing Index j
|
||||
sharedClique clique = (*this)[j];
|
||||
|
||||
// calculate or retrieve its marginal
|
||||
FactorGraph<FactorType> cliqueMarginal = clique->marginal(root_,function);
|
||||
// calculate or retrieve its marginal P(C) = P(F,S)
|
||||
FactorGraph<FactorType> cliqueMarginal = clique->marginal2(root_,function);
|
||||
|
||||
return GenericSequentialSolver<FactorType>(cliqueMarginal).marginalFactor(
|
||||
j, function);
|
||||
}
|
||||
// now, marginalize out everything that is not variable j
|
||||
GenericSequentialSolver<FactorType> solver(cliqueMarginal);
|
||||
return solver.marginalFactor(j, function);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class CONDITIONAL, class CLIQUE>
|
||||
|
|
|
@ -219,7 +219,9 @@ namespace gtsam {
|
|||
void clear();
|
||||
|
||||
/** Clear all shortcut caches - use before timing on marginal calculation to avoid residual cache data */
|
||||
inline void deleteCachedShorcuts() { root_->deleteCachedShorcuts(); }
|
||||
inline void deleteCachedShorcuts() {
|
||||
root_->deleteCachedShorcuts();
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove path from clique to root and return that path as factors
|
||||
|
|
|
@ -212,6 +212,58 @@ namespace gtsam {
|
|||
return *solver.jointFactorGraph(conditional_->keys(), function);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// separator marginal, uses separator marginal of parent recursively
|
||||
// P(C) = P(F|S) P(S)
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
FactorGraph<typename BayesTreeCliqueBase<DERIVED, CONDITIONAL>::FactorType> BayesTreeCliqueBase<
|
||||
DERIVED, CONDITIONAL>::separatorMarginal(derived_ptr R, Eliminate function) const {
|
||||
// Check if the Separator marginal was already calculated
|
||||
if (!cachedSeparatorMarginal_) {
|
||||
|
||||
// If this is the root, there is no separator
|
||||
if (R.get() == this) {
|
||||
// we are root, return empty
|
||||
FactorGraph<FactorType> empty;
|
||||
cachedSeparatorMarginal_ = empty;
|
||||
} else {
|
||||
// Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp)
|
||||
// initialize P(Cp) with the parent separator marginal
|
||||
derived_ptr parent(parent_.lock());
|
||||
FactorGraph<FactorType> p_Cp(parent->separatorMarginal(R, function)); // P(Sp)
|
||||
// now add the parent cobnditional
|
||||
p_Cp.push_back(parent->conditional()->toFactor()); // P(Fp|Sp)
|
||||
|
||||
// Create solver that will marginalize for us
|
||||
GenericSequentialSolver<FactorType> solver(p_Cp);
|
||||
|
||||
// The variables we want to keep are exactly the ones in S
|
||||
sharedConditional p_F_S = this->conditional();
|
||||
std::vector<Index> indicesS(p_F_S->beginParents(), p_F_S->endParents());
|
||||
|
||||
cachedSeparatorMarginal_ = *(solver.jointBayesNet(indicesS, function));
|
||||
}
|
||||
}
|
||||
|
||||
// return the shortcut P(S||B)
|
||||
return *cachedSeparatorMarginal_; // return the cached version
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// marginal2, uses separator marginal of parent recursively
|
||||
// P(C) = P(F|S) P(S)
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
FactorGraph<typename BayesTreeCliqueBase<DERIVED, CONDITIONAL>::FactorType> BayesTreeCliqueBase<
|
||||
DERIVED, CONDITIONAL>::marginal2(derived_ptr R, Eliminate function) const {
|
||||
// initialize with separator marginal P(S)
|
||||
FactorGraph<FactorType> p_C(this->separatorMarginal(R, function));
|
||||
// add the conditional P(F|S)
|
||||
p_C.push_back(this->conditional()->toFactor());
|
||||
return p_C;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R)
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -83,6 +83,9 @@ namespace gtsam {
|
|||
/// This stores the Cached Shortcut value
|
||||
mutable boost::optional<BayesNet<ConditionalType> > cachedShortcut_;
|
||||
|
||||
/// This stores the Cached separator margnal P(S)
|
||||
mutable boost::optional<FactorGraph<FactorType> > cachedSeparatorMarginal_;
|
||||
|
||||
public:
|
||||
sharedConditional conditional_;
|
||||
derived_weak_ptr parent_;
|
||||
|
@ -192,6 +195,14 @@ namespace gtsam {
|
|||
FactorGraph<FactorType> marginal(derived_ptr root,
|
||||
Eliminate function) const;
|
||||
|
||||
/** return the conditional P(S|Root) on the separator given the root */
|
||||
FactorGraph<FactorType> separatorMarginal(derived_ptr root,
|
||||
Eliminate function) const;
|
||||
|
||||
/** return the marginal P(C) of the clique, using separator shortcuts */
|
||||
FactorGraph<FactorType> marginal2(derived_ptr root,
|
||||
Eliminate function) const;
|
||||
|
||||
/**
|
||||
* return the joint P(C1,C2), where C1==this. TODO: not a method?
|
||||
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
|
||||
|
@ -233,6 +244,7 @@ namespace gtsam {
|
|||
|
||||
/// Reset the computed shortcut of this clique. Used by friend BayesTree
|
||||
void resetCachedShortcut() {
|
||||
cachedSeparatorMarginal_ = boost::none;
|
||||
cachedShortcut_ = boost::none;
|
||||
}
|
||||
|
||||
|
|
|
@ -200,8 +200,8 @@ namespace gtsam {
|
|||
const std::vector<Index>& js, Eliminate function) const {
|
||||
|
||||
// Eliminate all variables
|
||||
typename BayesNet<Conditional>::shared_ptr bayesNet = jointBayesNet(js,
|
||||
function);
|
||||
typename BayesNet<Conditional>::shared_ptr bayesNet = //
|
||||
jointBayesNet(js, function);
|
||||
|
||||
return boost::make_shared<FactorGraph<FACTOR> >(*bayesNet);
|
||||
}
|
||||
|
@ -216,6 +216,7 @@ namespace gtsam {
|
|||
js[0] = j;
|
||||
|
||||
// Call joint and return the only factor in the factor graph it returns
|
||||
// TODO: just call jointBayesNet and grab last conditional, then toFactor....
|
||||
return (*this->jointFactorGraph(js, function))[0];
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue