Fixed shortcuts after adding several more problematic testcases
parent
34a9000134
commit
cdf45105c2
|
@ -863,6 +863,14 @@
|
|||
<useDefaultCommand>true</useDefaultCommand>
|
||||
<runAllBuilders>true</runAllBuilders>
|
||||
</target>
|
||||
<target name="testSymbolicBayesTree.run" path="build/gtsam/inference" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
|
||||
<buildCommand>make</buildCommand>
|
||||
<buildArguments>-j1</buildArguments>
|
||||
<buildTarget>testSymbolicBayesTree.run</buildTarget>
|
||||
<stopOnError>true</stopOnError>
|
||||
<useDefaultCommand>false</useDefaultCommand>
|
||||
<runAllBuilders>true</runAllBuilders>
|
||||
</target>
|
||||
<target name="vSFMexample.run" path="build/examples/vSLAMexample" targetID="org.eclipse.cdt.build.MakeTargetBuilder">
|
||||
<buildCommand>make</buildCommand>
|
||||
<buildArguments>-j2</buildArguments>
|
||||
|
|
|
@ -36,40 +36,6 @@ class Clique: public BayesTreeCliqueBase<Clique, DiscreteConditional> {
|
|||
|
||||
protected:
|
||||
|
||||
/// Calculate set S\B
|
||||
vector<Index> separatorShortcutVariables(derived_ptr B) const {
|
||||
sharedConditional p_F_S = this->conditional();
|
||||
vector<Index> &indicesB = B->conditional()->keys();
|
||||
vector<Index> S_setminus_B;
|
||||
set_difference(p_F_S->beginParents(), p_F_S->endParents(), //
|
||||
indicesB.begin(), indicesB.end(), back_inserter(S_setminus_B));
|
||||
return S_setminus_B;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine variable indices to keep in recursive separator shortcut calculation
|
||||
* The factor graph p_Cp_B has keys from the parent clique Cp and from B.
|
||||
* But we only keep the variables not in S union B.
|
||||
*/
|
||||
vector<Index> indices(derived_ptr B,
|
||||
const FactorGraph<FactorType>& p_Cp_B) const {
|
||||
|
||||
// We do this by first merging S and B
|
||||
sharedConditional p_F_S = this->conditional();
|
||||
vector<Index> &indicesB = B->conditional()->keys();
|
||||
vector<Index> S_union_B;
|
||||
set_union(p_F_S->beginParents(), p_F_S->endParents(), //
|
||||
indicesB.begin(), indicesB.end(), back_inserter(S_union_B));
|
||||
|
||||
// then intersecting S_union_B with all keys in p_Cp_B
|
||||
set<Index> allKeys = p_Cp_B.keys();
|
||||
vector<Index> keepers;
|
||||
set_intersection(S_union_B.begin(), S_union_B.end(), //
|
||||
allKeys.begin(), allKeys.end(), back_inserter(keepers));
|
||||
|
||||
return keepers;
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
typedef BayesTreeCliqueBase<Clique, DiscreteConditional> Base;
|
||||
|
@ -102,88 +68,6 @@ public:
|
|||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Separator shortcut function P(S||B) = P(S\B|B)
|
||||
* where S is a clique separator, and B any node (e.g., a brancing in the tree)
|
||||
* We can compute it recursively from the parent shortcut
|
||||
* P(Sp||B) as \int P(Fp|Sp) P(Sp||B), where Fp are the frontal nodes in p
|
||||
*/
|
||||
FactorGraph<FactorType>::shared_ptr separatorShortcut(derived_ptr B) const {
|
||||
|
||||
typedef FactorGraph<FactorType> FG;
|
||||
|
||||
FG::shared_ptr p_S_B; //shortcut P(S||B) This is empty now
|
||||
|
||||
// We only calculate the shortcut when this clique is not B
|
||||
// and when the S\B is not empty
|
||||
vector<Index> S_setminus_B = separatorShortcutVariables(B);
|
||||
if (B.get() != this && !S_setminus_B.empty()) {
|
||||
|
||||
// Obtain P(Fp|Sp) as a factor
|
||||
derived_ptr parent(parent_.lock());
|
||||
boost::shared_ptr<FactorType> p_Fp_Sp = parent->conditional()->toFactor();
|
||||
|
||||
// Obtain the parent shortcut P(Sp|B) as factors
|
||||
// TODO: really annoying that we eliminate more than we have to !
|
||||
// TODO: we should only eliminate C_p\B, with S\B variables last
|
||||
// TODO: and this index dance will be easier then, as well
|
||||
FG p_Sp_B(parent->shortcut(B, &EliminateDiscrete));
|
||||
|
||||
// now combine P(Cp||B) = P(Fp|Sp) * P(Sp||B)
|
||||
boost::shared_ptr<FG> p_Cp_B(new FG);
|
||||
p_Cp_B->push_back(p_Fp_Sp);
|
||||
p_Cp_B->push_back(p_Sp_B);
|
||||
|
||||
// Figure out how many variables there are in in the shortcut
|
||||
// size_t nVariables = *max_element(S_setminus_B.begin(),S_setminus_B.end());
|
||||
// cout << "nVariables: " << nVariables << endl;
|
||||
// VariableIndex::shared_ptr structure(new VariableIndex(*p_Cp_B));
|
||||
// GTSAM_PRINT(*p_Cp_B);
|
||||
// GTSAM_PRINT(*structure);
|
||||
|
||||
// Create a generic solver that will marginalize for us
|
||||
GenericSequentialSolver<FactorType> solver(*p_Cp_B);
|
||||
|
||||
// The factor graph above will have keys from the parent clique Cp and from B.
|
||||
// But we only keep the variables not in S union B.
|
||||
vector<Index> keepers = indices(B, *p_Cp_B);
|
||||
|
||||
p_S_B = solver.jointFactorGraph(keepers, &EliminateDiscrete);
|
||||
}
|
||||
// return the shortcut P(S||B)
|
||||
return p_S_B;
|
||||
}
|
||||
|
||||
/**
|
||||
* The shortcut density is a conditional P(S||B) of the separator of this
|
||||
* clique on the clique B.
|
||||
*/
|
||||
BayesNet<DiscreteConditional> shortcut(derived_ptr B,
|
||||
Eliminate function) const {
|
||||
|
||||
//Check if the ShortCut already exists
|
||||
if (cachedShortcut_) {
|
||||
return *cachedShortcut_; // return the cached version
|
||||
} else {
|
||||
BayesNet<DiscreteConditional> bn;
|
||||
FactorGraph<FactorType>::shared_ptr fg = separatorShortcut(B);
|
||||
if (fg) {
|
||||
// calculate set S\B of indices to keep in Bayes net
|
||||
vector<Index> S_setminus_B = separatorShortcutVariables(B);
|
||||
set<Index> keep(S_setminus_B.begin(), S_setminus_B.end());
|
||||
|
||||
BOOST_FOREACH (FactorType::shared_ptr factor,*fg) {
|
||||
DecisionTreeFactor::shared_ptr df = boost::dynamic_pointer_cast<
|
||||
DecisionTreeFactor>(factor);
|
||||
if (keep.count(*factor->begin()))
|
||||
bn.push_front(boost::make_shared<DiscreteConditional>(1, *df));
|
||||
}
|
||||
}
|
||||
cachedShortcut_ = bn;
|
||||
return bn;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
typedef BayesTree<DiscreteConditional, Clique> DiscreteBayesTree;
|
||||
|
@ -196,7 +80,7 @@ double evaluate(const DiscreteBayesTree& tree,
|
|||
|
||||
/* ************************************************************************* */
|
||||
|
||||
TEST_UNSAFE( DiscreteMarginals, thinTree ) {
|
||||
TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
||||
|
||||
const int nrNodes = 15;
|
||||
const size_t nrStates = 2;
|
||||
|
|
|
@ -22,7 +22,7 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
void BayesTreeCliqueBase<DERIVED,CONDITIONAL>::assertInvariants() const {
|
||||
void BayesTreeCliqueBase<DERIVED, CONDITIONAL>::assertInvariants() const {
|
||||
#ifndef NDEBUG
|
||||
// We rely on the keys being sorted
|
||||
// FastVector<Index> sortedUniqueKeys(conditional_->begin(), conditional_->end());
|
||||
|
@ -35,27 +35,71 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
BayesTreeCliqueBase<DERIVED,CONDITIONAL>::BayesTreeCliqueBase(const sharedConditional& conditional) :
|
||||
conditional_(conditional) {
|
||||
std::vector<Index> BayesTreeCliqueBase<DERIVED, CONDITIONAL>::separator_setminus_B(
|
||||
derived_ptr B) const {
|
||||
sharedConditional p_F_S = this->conditional();
|
||||
std::vector<Index> &indicesB = B->conditional()->keys();
|
||||
std::vector<Index> S_setminus_B;
|
||||
std::set_difference(p_F_S->beginParents(), p_F_S->endParents(), //
|
||||
indicesB.begin(), indicesB.end(), back_inserter(S_setminus_B));
|
||||
return S_setminus_B;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
std::vector<Index> BayesTreeCliqueBase<DERIVED, CONDITIONAL>::shortcut_indices(
|
||||
derived_ptr B, const FactorGraph<FactorType>& p_Cp_B) const {
|
||||
std::set<Index> allKeys = p_Cp_B.keys();
|
||||
std::vector<Index> &indicesB = B->conditional()->keys();
|
||||
std::vector<Index> keep;
|
||||
#ifdef OLD_INDICES
|
||||
// We do this by first merging S and B
|
||||
sharedConditional p_F_S = this->conditional();
|
||||
std::vector<Index> S_union_B;
|
||||
std::set_union(p_F_S->beginParents(), p_F_S->endParents(),//
|
||||
indicesB.begin(), indicesB.end(), back_inserter(S_union_B));
|
||||
|
||||
// then intersecting S_union_B with all keys in p_Cp_B
|
||||
std::set_intersection(S_union_B.begin(), S_union_B.end(),//
|
||||
allKeys.begin(), allKeys.end(), back_inserter(keep));
|
||||
#else
|
||||
std::vector<Index> S_setminus_B = separator_setminus_B(B); // TODO, get as argument?
|
||||
std::set_intersection(S_setminus_B.begin(), S_setminus_B.end(), //
|
||||
allKeys.begin(), allKeys.end(), back_inserter(keep));
|
||||
std::set_intersection(indicesB.begin(), indicesB.end(), //
|
||||
allKeys.begin(), allKeys.end(), back_inserter(keep));
|
||||
#endif
|
||||
// BOOST_FOREACH(Index j, keep) std::cout << j << " "; std::cout << std::endl;
|
||||
return keep;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
BayesTreeCliqueBase<DERIVED, CONDITIONAL>::BayesTreeCliqueBase(
|
||||
const sharedConditional& conditional) :
|
||||
conditional_(conditional) {
|
||||
assertInvariants();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
BayesTreeCliqueBase<DERIVED,CONDITIONAL>::BayesTreeCliqueBase(const std::pair<sharedConditional, boost::shared_ptr<typename ConditionalType::FactorType> >& result) :
|
||||
conditional_(result.first) {
|
||||
BayesTreeCliqueBase<DERIVED, CONDITIONAL>::BayesTreeCliqueBase(
|
||||
const std::pair<sharedConditional,
|
||||
boost::shared_ptr<typename ConditionalType::FactorType> >& result) :
|
||||
conditional_(result.first) {
|
||||
assertInvariants();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
void BayesTreeCliqueBase<DERIVED,CONDITIONAL>::print(const std::string& s, const IndexFormatter& indexFormatter) const {
|
||||
void BayesTreeCliqueBase<DERIVED, CONDITIONAL>::print(const std::string& s,
|
||||
const IndexFormatter& indexFormatter) const {
|
||||
conditional_->print(s, indexFormatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
size_t BayesTreeCliqueBase<DERIVED,CONDITIONAL>::treeSize() const {
|
||||
size_t BayesTreeCliqueBase<DERIVED, CONDITIONAL>::treeSize() const {
|
||||
size_t size = 1;
|
||||
BOOST_FOREACH(const derived_ptr& child, children_)
|
||||
size += child->treeSize();
|
||||
|
@ -64,15 +108,17 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
void BayesTreeCliqueBase<DERIVED,CONDITIONAL>::printTree(const std::string& indent, const IndexFormatter& indexFormatter) const {
|
||||
void BayesTreeCliqueBase<DERIVED, CONDITIONAL>::printTree(
|
||||
const std::string& indent, const IndexFormatter& indexFormatter) const {
|
||||
asDerived(this)->print(indent, indexFormatter);
|
||||
BOOST_FOREACH(const derived_ptr& child, children_)
|
||||
child->printTree(indent+" ", indexFormatter);
|
||||
child->printTree(indent + " ", indexFormatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
void BayesTreeCliqueBase<DERIVED,CONDITIONAL>::permuteWithInverse(const Permutation& inversePermutation) {
|
||||
void BayesTreeCliqueBase<DERIVED, CONDITIONAL>::permuteWithInverse(
|
||||
const Permutation& inversePermutation) {
|
||||
conditional_->permuteWithInverse(inversePermutation);
|
||||
BOOST_FOREACH(const derived_ptr& child, children_) {
|
||||
child->permuteWithInverse(inversePermutation);
|
||||
|
@ -82,19 +128,21 @@ namespace gtsam {
|
|||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
bool BayesTreeCliqueBase<DERIVED,CONDITIONAL>::permuteSeparatorWithInverse(const Permutation& inversePermutation) {
|
||||
bool changed = conditional_->permuteSeparatorWithInverse(inversePermutation);
|
||||
bool BayesTreeCliqueBase<DERIVED, CONDITIONAL>::permuteSeparatorWithInverse(
|
||||
const Permutation& inversePermutation) {
|
||||
bool changed = conditional_->permuteSeparatorWithInverse(
|
||||
inversePermutation);
|
||||
#ifndef NDEBUG
|
||||
if(!changed) {
|
||||
BOOST_FOREACH(Index& separatorKey, conditional_->parents()) { assert(separatorKey == inversePermutation[separatorKey]); }
|
||||
BOOST_FOREACH(Index& separatorKey, conditional_->parents()) {assert(separatorKey == inversePermutation[separatorKey]);}
|
||||
BOOST_FOREACH(const derived_ptr& child, children_) {
|
||||
assert(child->permuteSeparatorWithInverse(inversePermutation) == false);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if(changed) {
|
||||
if (changed) {
|
||||
BOOST_FOREACH(const derived_ptr& child, children_) {
|
||||
(void)child->permuteSeparatorWithInverse(inversePermutation);
|
||||
(void) child->permuteSeparatorWithInverse(inversePermutation);
|
||||
}
|
||||
}
|
||||
assertInvariants();
|
||||
|
@ -108,105 +156,47 @@ namespace gtsam {
|
|||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
BayesNet<CONDITIONAL> BayesTreeCliqueBase<DERIVED, CONDITIONAL>::shortcut(
|
||||
derived_ptr R, Eliminate function) const{
|
||||
derived_ptr B, Eliminate function) const {
|
||||
|
||||
static const bool debug = false;
|
||||
// Check if the ShortCut already exists
|
||||
if (!cachedShortcut_) {
|
||||
|
||||
BayesNet<ConditionalType> p_S_R; //shortcut P(S|R) This is empty now
|
||||
// We only calculate the shortcut when this clique is not B
|
||||
// and when the S\B is not empty
|
||||
std::vector<Index> S_setminus_B = separator_setminus_B(B);
|
||||
if (B.get() != this && !S_setminus_B.empty()) {
|
||||
|
||||
//Check if the ShortCut already exists
|
||||
if(!cachedShortcut_){
|
||||
// Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
|
||||
derived_ptr parent(parent_.lock());
|
||||
FactorGraph<FactorType> p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
|
||||
p_Cp_B.push_back(parent->conditional()->toFactor()); // P(Fp|Sp)
|
||||
|
||||
// A first base case is when this clique or its parent is the root,
|
||||
// in which case we return an empty Bayes net.
|
||||
// Add the root conditional
|
||||
// TODO: this is needed because otherwise we will be solving singular
|
||||
// systems and exceptions are thrown. However, we should be able to omit
|
||||
// this if we can get ATTEMPT_AT_NOT_ELIMINATING_ALL in
|
||||
// GenericSequentialSolver.* working...
|
||||
p_Cp_B.push_back(B->conditional()->toFactor()); // P(B)
|
||||
|
||||
derived_ptr parent(parent_.lock());
|
||||
if (R.get() != this && parent != R) {
|
||||
// Create solver that will marginalize for us
|
||||
GenericSequentialSolver<FactorType> solver(p_Cp_B);
|
||||
|
||||
// The root conditional
|
||||
FactorGraph<FactorType> p_R(BayesNet<ConditionalType>(R->conditional()));
|
||||
// Determine the variables we want to keep
|
||||
std::vector<Index> keep = shortcut_indices(B, p_Cp_B);
|
||||
|
||||
// The parent clique has a ConditionalType for each frontal node in Fp
|
||||
// so we can obtain P(Fp|Sp) in factor graph form
|
||||
FactorGraph<FactorType> p_Fp_Sp(BayesNet<ConditionalType>(parent->conditional()));
|
||||
// Finally, we only want to have S\B variables in the Bayes net, so
|
||||
size_t nrFrontals = S_setminus_B.size();
|
||||
cachedShortcut_ = //
|
||||
*solver.conditionalBayesNet(keep, nrFrontals, function);
|
||||
assertInvariants();
|
||||
} else {
|
||||
BayesNet<CONDITIONAL> empty;
|
||||
cachedShortcut_ = empty;
|
||||
}
|
||||
}
|
||||
|
||||
// If not the base case, obtain the parent shortcut P(Sp|R) as factors
|
||||
FactorGraph<FactorType> p_Sp_R(parent->shortcut(R, function));
|
||||
|
||||
// now combine P(Cp|R) = P(Fp|Sp) * P(Sp|R)
|
||||
FactorGraph<FactorType> p_Cp_R;
|
||||
p_Cp_R.push_back(p_R);
|
||||
p_Cp_R.push_back(p_Fp_Sp);
|
||||
p_Cp_R.push_back(p_Sp_R);
|
||||
|
||||
// Eliminate into a Bayes net with ordering designed to integrate out
|
||||
// any variables not in *our* separator. Variables to integrate out must be
|
||||
// eliminated first hence the desired ordering is [Cp\S S].
|
||||
// However, an added wrinkle is that Cp might overlap with the root.
|
||||
// Keys corresponding to the root should not be added to the ordering at all.
|
||||
|
||||
if(debug) {
|
||||
p_R.print("p_R: ");
|
||||
p_Fp_Sp.print("p_Fp_Sp: ");
|
||||
p_Sp_R.print("p_Sp_R: ");
|
||||
}
|
||||
|
||||
// We want to factor into a conditional of the clique variables given the
|
||||
// root and the marginal on the root, integrating out all other variables.
|
||||
// The integrands include any parents of this clique and the variables of
|
||||
// the parent clique.
|
||||
FastSet<Index> variablesAtBack;
|
||||
FastSet<Index> separator;
|
||||
size_t uniqueRootVariables = 0;
|
||||
BOOST_FOREACH(const Index separatorIndex, this->conditional()->parents()) {
|
||||
variablesAtBack.insert(separatorIndex);
|
||||
separator.insert(separatorIndex);
|
||||
if(debug) std::cout << "At back (this): " << separatorIndex << std::endl;
|
||||
}
|
||||
BOOST_FOREACH(const Index key, R->conditional()->keys()) {
|
||||
if(variablesAtBack.insert(key).second)
|
||||
++ uniqueRootVariables;
|
||||
if(debug) std::cout << "At back (root): " << key << std::endl;
|
||||
}
|
||||
|
||||
Permutation toBack = Permutation::PushToBack(
|
||||
std::vector<Index>(variablesAtBack.begin(), variablesAtBack.end()),
|
||||
R->conditional()->lastFrontalKey() + 1);
|
||||
Permutation::shared_ptr toBackInverse(toBack.inverse());
|
||||
BOOST_FOREACH(const typename FactorType::shared_ptr& factor, p_Cp_R) {
|
||||
factor->permuteWithInverse(*toBackInverse); }
|
||||
typename BayesNet<ConditionalType>::shared_ptr eliminated(EliminationTree<
|
||||
FactorType>::Create(p_Cp_R)->eliminate(function));
|
||||
|
||||
// Take only the conditionals for p(S|R). We check for each variable being
|
||||
// in the separator set because if some separator variables overlap with
|
||||
// root variables, we cannot rely on the number of root variables, and also
|
||||
// want to include those variables in the conditional.
|
||||
BOOST_REVERSE_FOREACH(typename ConditionalType::shared_ptr conditional, *eliminated) {
|
||||
assert(conditional->nrFrontals() == 1);
|
||||
if(separator.find(toBack[conditional->firstFrontalKey()]) != separator.end()) {
|
||||
if(debug)
|
||||
conditional->print("Taking C|R conditional: ");
|
||||
p_S_R.push_front(conditional);
|
||||
}
|
||||
if(p_S_R.size() == separator.size())
|
||||
break;
|
||||
}
|
||||
|
||||
// Undo the permutation
|
||||
if(debug) toBack.print("toBack: ");
|
||||
p_S_R.permuteWithInverse(toBack);
|
||||
}
|
||||
|
||||
cachedShortcut_ = p_S_R;
|
||||
}
|
||||
else
|
||||
p_S_R = *cachedShortcut_; // return the cached version
|
||||
|
||||
assertInvariants();
|
||||
|
||||
// return the shortcut P(S|R)
|
||||
return p_S_R;
|
||||
// return the shortcut P(S||B)
|
||||
return *cachedShortcut_; // return the cached version
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -216,12 +206,13 @@ namespace gtsam {
|
|||
// Because the root clique could be very big.
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
FactorGraph<typename BayesTreeCliqueBase<DERIVED,CONDITIONAL>::FactorType> BayesTreeCliqueBase<DERIVED,CONDITIONAL>::marginal(
|
||||
derived_ptr R, Eliminate function) const{
|
||||
FactorGraph<typename BayesTreeCliqueBase<DERIVED, CONDITIONAL>::FactorType> BayesTreeCliqueBase<
|
||||
DERIVED, CONDITIONAL>::marginal(derived_ptr R, Eliminate function) const {
|
||||
// If we are the root, just return this root
|
||||
// NOTE: immediately cast to a factor graph
|
||||
BayesNet<ConditionalType> bn(R->conditional());
|
||||
if (R.get()==this) return bn;
|
||||
if (R.get() == this)
|
||||
return bn;
|
||||
|
||||
// Combine P(F|S), P(S|R), and P(R)
|
||||
BayesNet<ConditionalType> p_FSR = this->shortcut(R, function);
|
||||
|
@ -237,16 +228,21 @@ namespace gtsam {
|
|||
// P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R)
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
FactorGraph<typename BayesTreeCliqueBase<DERIVED,CONDITIONAL>::FactorType> BayesTreeCliqueBase<DERIVED,CONDITIONAL>::joint(
|
||||
derived_ptr C2, derived_ptr R, Eliminate function) const {
|
||||
FactorGraph<typename BayesTreeCliqueBase<DERIVED, CONDITIONAL>::FactorType> BayesTreeCliqueBase<
|
||||
DERIVED, CONDITIONAL>::joint(derived_ptr C2, derived_ptr R,
|
||||
Eliminate function) const {
|
||||
// For now, assume neither is the root
|
||||
|
||||
// Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R)
|
||||
FactorGraph<FactorType> joint;
|
||||
if (!isRoot()) joint.push_back(this->conditional()->toFactor()); // P(F1|S1)
|
||||
if (!isRoot()) joint.push_back(shortcut(R, function)); // P(S1|R)
|
||||
if (!C2->isRoot()) joint.push_back(C2->conditional()->toFactor()); // P(F2|S2)
|
||||
if (!C2->isRoot()) joint.push_back(C2->shortcut(R, function)); // P(S2|R)
|
||||
if (!isRoot())
|
||||
joint.push_back(this->conditional()->toFactor()); // P(F1|S1)
|
||||
if (!isRoot())
|
||||
joint.push_back(shortcut(R, function)); // P(S1|R)
|
||||
if (!C2->isRoot())
|
||||
joint.push_back(C2->conditional()->toFactor()); // P(F2|S2)
|
||||
if (!C2->isRoot())
|
||||
joint.push_back(C2->shortcut(R, function)); // P(S2|R)
|
||||
joint.push_back(R->conditional()->toFactor()); // P(R)
|
||||
|
||||
// Find the keys of both C1 and C2
|
||||
|
@ -257,29 +253,30 @@ namespace gtsam {
|
|||
keys12.insert(keys2.begin(), keys2.end());
|
||||
|
||||
// Calculate the marginal
|
||||
std::vector<Index> keys12vector; keys12vector.reserve(keys12.size());
|
||||
std::vector<Index> keys12vector;
|
||||
keys12vector.reserve(keys12.size());
|
||||
keys12vector.insert(keys12vector.begin(), keys12.begin(), keys12.end());
|
||||
assertInvariants();
|
||||
GenericSequentialSolver<FactorType> solver(joint);
|
||||
return *solver.jointFactorGraph(keys12vector, function);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
void BayesTreeCliqueBase<DERIVED, CONDITIONAL>::deleteCachedShorcuts() {
|
||||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
void BayesTreeCliqueBase<DERIVED, CONDITIONAL>::deleteCachedShorcuts() {
|
||||
|
||||
// When a shortcut is requested, all of the shortcuts between it and the
|
||||
// root are also generated. So, if this clique's cached shortcut is set,
|
||||
// recursively call over all child cliques. Otherwise, it is unnecessary.
|
||||
if(cachedShortcut_) {
|
||||
BOOST_FOREACH(derived_ptr& child, children_) {
|
||||
child->deleteCachedShorcuts();
|
||||
}
|
||||
// When a shortcut is requested, all of the shortcuts between it and the
|
||||
// root are also generated. So, if this clique's cached shortcut is set,
|
||||
// recursively call over all child cliques. Otherwise, it is unnecessary.
|
||||
if (cachedShortcut_) {
|
||||
BOOST_FOREACH(derived_ptr& child, children_) {
|
||||
child->deleteCachedShorcuts();
|
||||
}
|
||||
|
||||
//Delete CachedShortcut for this clique
|
||||
this->resetCachedShortcut();
|
||||
}
|
||||
//Delete CachedShortcut for this clique
|
||||
this->resetCachedShortcut();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -25,7 +25,9 @@
|
|||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
|
||||
namespace gtsam { template<class CONDITIONAL, class CLIQUE> class BayesTree; }
|
||||
namespace gtsam {
|
||||
template<class CONDITIONAL, class CLIQUE> class BayesTree;
|
||||
}
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -48,7 +50,7 @@ namespace gtsam {
|
|||
struct BayesTreeCliqueBase {
|
||||
|
||||
public:
|
||||
typedef BayesTreeCliqueBase<DERIVED,CONDITIONAL> This;
|
||||
typedef BayesTreeCliqueBase<DERIVED, CONDITIONAL> This;
|
||||
typedef DERIVED DerivedType;
|
||||
typedef CONDITIONAL ConditionalType;
|
||||
typedef boost::shared_ptr<ConditionalType> sharedConditional;
|
||||
|
@ -61,19 +63,22 @@ namespace gtsam {
|
|||
|
||||
protected:
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Default constructor */
|
||||
BayesTreeCliqueBase() {}
|
||||
BayesTreeCliqueBase() {
|
||||
}
|
||||
|
||||
/** Construct from a conditional, leaving parent and child pointers uninitialized */
|
||||
BayesTreeCliqueBase(const sharedConditional& conditional);
|
||||
|
||||
/** Construct from an elimination result, which is a pair<CONDITIONAL,FACTOR> */
|
||||
BayesTreeCliqueBase(const std::pair<sharedConditional, boost::shared_ptr<typename ConditionalType::FactorType> >& result);
|
||||
BayesTreeCliqueBase(
|
||||
const std::pair<sharedConditional,
|
||||
boost::shared_ptr<typename ConditionalType::FactorType> >& result);
|
||||
|
||||
/// @}
|
||||
/// @}
|
||||
|
||||
/// This stores the Cached Shortcut value
|
||||
mutable boost::optional<BayesNet<ConditionalType> > cachedShortcut_;
|
||||
|
@ -83,67 +88,91 @@ namespace gtsam {
|
|||
derived_weak_ptr parent_;
|
||||
std::list<derived_ptr> children_;
|
||||
|
||||
/// @name Testable
|
||||
/// @{
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** check equality */
|
||||
bool equals(const This& other, double tol=1e-9) const {
|
||||
return (!conditional_ && !other.conditional()) ||
|
||||
conditional_->equals(*other.conditional(), tol);
|
||||
bool equals(const This& other, double tol = 1e-9) const {
|
||||
return (!conditional_ && !other.conditional())
|
||||
|| conditional_->equals(*other.conditional(), tol);
|
||||
}
|
||||
|
||||
/** print this node */
|
||||
void print(const std::string& s = "", const IndexFormatter& indexFormatter = DefaultIndexFormatter ) const;
|
||||
void print(const std::string& s = "", const IndexFormatter& indexFormatter =
|
||||
DefaultIndexFormatter) const;
|
||||
|
||||
/** print this node and entire subtree below it */
|
||||
void printTree(const std::string& indent="", const IndexFormatter& indexFormatter = DefaultIndexFormatter ) const;
|
||||
void printTree(const std::string& indent = "",
|
||||
const IndexFormatter& indexFormatter = DefaultIndexFormatter) const;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Access the conditional */
|
||||
const sharedConditional& conditional() const { return conditional_; }
|
||||
const sharedConditional& conditional() const {
|
||||
return conditional_;
|
||||
}
|
||||
|
||||
/** is this the root of a Bayes tree ? */
|
||||
inline bool isRoot() const { return parent_.expired(); }
|
||||
inline bool isRoot() const {
|
||||
return parent_.expired();
|
||||
}
|
||||
|
||||
/** The size of subtree rooted at this clique, i.e., nr of Cliques */
|
||||
size_t treeSize() const;
|
||||
|
||||
/** The arrow operator accesses the conditional */
|
||||
const ConditionalType* operator->() const { return conditional_.get(); }
|
||||
const ConditionalType* operator->() const {
|
||||
return conditional_.get();
|
||||
}
|
||||
|
||||
/** return the const reference of children */
|
||||
const std::list<derived_ptr>& children() const { return children_; }
|
||||
const std::list<derived_ptr>& children() const {
|
||||
return children_;
|
||||
}
|
||||
|
||||
/** return a shared_ptr to the parent clique */
|
||||
derived_ptr parent() const { return parent_.lock(); }
|
||||
derived_ptr parent() const {
|
||||
return parent_.lock();
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/** The arrow operator accesses the conditional */
|
||||
ConditionalType* operator->() { return conditional_.get(); }
|
||||
ConditionalType* operator->() {
|
||||
return conditional_.get();
|
||||
}
|
||||
|
||||
/** return the reference of children non-const version*/
|
||||
std::list<derived_ptr>& children() { return children_; }
|
||||
std::list<derived_ptr>& children() {
|
||||
return children_;
|
||||
}
|
||||
|
||||
/** Construct shared_ptr from a conditional, leaving parent and child pointers uninitialized */
|
||||
static derived_ptr Create(const sharedConditional& conditional) { return boost::make_shared<DerivedType>(conditional); }
|
||||
static derived_ptr Create(const sharedConditional& conditional) {
|
||||
return boost::make_shared<DerivedType>(conditional);
|
||||
}
|
||||
|
||||
/** Construct shared_ptr from a FactorGraph<FACTOR>::EliminationResult. In this class
|
||||
* the conditional part is kept and the factor part is ignored, but in derived clique
|
||||
* types, such as ISAM2Clique, the factor part is kept as a cached factor.
|
||||
* @param result An elimination result, which is a pair<CONDITIONAL,FACTOR>
|
||||
*/
|
||||
static derived_ptr Create(const std::pair<sharedConditional, boost::shared_ptr<typename ConditionalType::FactorType> >& result) { return boost::make_shared<DerivedType>(result); }
|
||||
static derived_ptr Create(
|
||||
const std::pair<sharedConditional,
|
||||
boost::shared_ptr<typename ConditionalType::FactorType> >& result) {
|
||||
return boost::make_shared<DerivedType>(result);
|
||||
}
|
||||
|
||||
/** Returns a new clique containing a copy of the conditional but without
|
||||
* the parent and child clique pointers.
|
||||
*/
|
||||
derived_ptr clone() const { return Create(sharedConditional(new ConditionalType(*conditional_))); }
|
||||
/** Returns a new clique containing a copy of the conditional but without
|
||||
* the parent and child clique pointers.
|
||||
*/
|
||||
derived_ptr clone() const {
|
||||
return Create(sharedConditional(new ConditionalType(*conditional_)));
|
||||
}
|
||||
|
||||
/** Permute the variables in the whole subtree rooted at this clique */
|
||||
void permuteWithInverse(const Permutation& inversePermutation);
|
||||
|
@ -156,13 +185,16 @@ namespace gtsam {
|
|||
bool permuteSeparatorWithInverse(const Permutation& inversePermutation);
|
||||
|
||||
/** return the conditional P(S|Root) on the separator given the root */
|
||||
BayesNet<ConditionalType> shortcut(derived_ptr root, Eliminate function) const;
|
||||
BayesNet<ConditionalType> shortcut(derived_ptr root,
|
||||
Eliminate function) const;
|
||||
|
||||
/** return the marginal P(C) of the clique */
|
||||
FactorGraph<FactorType> marginal(derived_ptr root, Eliminate function) const;
|
||||
FactorGraph<FactorType> marginal(derived_ptr root,
|
||||
Eliminate function) const;
|
||||
|
||||
/** return the joint P(C1,C2), where C1==this. TODO: not a method? */
|
||||
FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr root, Eliminate function) const;
|
||||
FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr root,
|
||||
Eliminate function) const;
|
||||
|
||||
/**
|
||||
* This deletes the cached shortcuts of all cliques (subtree) below this clique.
|
||||
|
@ -171,29 +203,53 @@ namespace gtsam {
|
|||
void deleteCachedShorcuts();
|
||||
|
||||
/** return cached shortcut of the clique */
|
||||
const boost::optional<BayesNet<ConditionalType> > cachedShortcut() const { return cachedShortcut_; }
|
||||
const boost::optional<BayesNet<ConditionalType> > cachedShortcut() const {
|
||||
return cachedShortcut_;
|
||||
}
|
||||
|
||||
friend class BayesTree<ConditionalType, DerivedType>;
|
||||
friend class BayesTree<ConditionalType, DerivedType> ;
|
||||
|
||||
protected:
|
||||
|
||||
///TODO: comment
|
||||
/// assert invariants that have to hold in a clique
|
||||
void assertInvariants() const;
|
||||
|
||||
/// Calculate set \f$ S \setminus B \f$ for shortcut calculations
|
||||
std::vector<Index> separator_setminus_B(derived_ptr B) const;
|
||||
|
||||
/// Calculate set \f$ S_p \cap B \f$ for shortcut calculations
|
||||
std::vector<Index> parent_separator_intersection_B(derived_ptr B) const;
|
||||
|
||||
/**
|
||||
* Determine variable indices to keep in recursive separator shortcut calculation
|
||||
* The factor graph p_Cp_B has keys from the parent clique Cp and from B.
|
||||
* But we only keep the variables not in S union B.
|
||||
*/
|
||||
std::vector<Index> shortcut_indices(derived_ptr B,
|
||||
const FactorGraph<FactorType>& p_Cp_B) const;
|
||||
|
||||
/// Reset the computed shortcut of this clique. Used by friend BayesTree
|
||||
void resetCachedShortcut() { cachedShortcut_ = boost::none; }
|
||||
void resetCachedShortcut() {
|
||||
cachedShortcut_ = boost::none;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
/** Cliques cannot be copied except by the clone() method, which does not
|
||||
/**
|
||||
* Cliques cannot be copied except by the clone() method, which does not
|
||||
* copy the parent and child pointers.
|
||||
*/
|
||||
BayesTreeCliqueBase(const This& other) { assert(false); }
|
||||
BayesTreeCliqueBase(const This& other) {
|
||||
assert(false);
|
||||
}
|
||||
|
||||
/** Cliques cannot be copied except by the clone() method, which does not
|
||||
* copy the parent and child pointers.
|
||||
*/
|
||||
This& operator=(const This& other) { assert(false); return *this; }
|
||||
This& operator=(const This& other) {
|
||||
assert(false);
|
||||
return *this;
|
||||
}
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
@ -204,17 +260,19 @@ namespace gtsam {
|
|||
ar & BOOST_SERIALIZATION_NVP(children_);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @}
|
||||
|
||||
}; // \struct Clique
|
||||
};
|
||||
// \struct Clique
|
||||
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
const DERIVED* asDerived(const BayesTreeCliqueBase<DERIVED,CONDITIONAL>* base) {
|
||||
const DERIVED* asDerived(
|
||||
const BayesTreeCliqueBase<DERIVED, CONDITIONAL>* base) {
|
||||
return static_cast<const DERIVED*>(base);
|
||||
}
|
||||
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
DERIVED* asDerived(BayesTreeCliqueBase<DERIVED,CONDITIONAL>* base) {
|
||||
DERIVED* asDerived(BayesTreeCliqueBase<DERIVED, CONDITIONAL>* base) {
|
||||
return static_cast<DERIVED*>(base);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,278 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testSymbolicBayesTree.cpp
|
||||
* @date sept 15, 2012
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/inference/SymbolicSequentialSolver.h>
|
||||
#include <gtsam/inference/SymbolicFactorGraph.h>
|
||||
#include <gtsam/inference/BayesTree.h>
|
||||
|
||||
#include <boost/assign/list_of.hpp>
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
static bool debug = false;
|
||||
|
||||
typedef BayesNet<IndexConditional> SymbolicBayesNet;
|
||||
typedef BayesTree<IndexConditional> SymbolicBayesTree;
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
TEST_UNSAFE( SymbolicBayesTree, thinTree ) {
|
||||
|
||||
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||
SymbolicBayesNet bayesNet;
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(14));
|
||||
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(13, 14));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(12, 14));
|
||||
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(11, 13, 14));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(10, 13, 14));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(9, 12, 14));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(8, 12, 14));
|
||||
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(7, 11, 13));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(6, 11, 13));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(5, 10, 13));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(4, 10, 13));
|
||||
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(3, 9, 12));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(2, 9, 12));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(1, 8, 12));
|
||||
bayesNet.push_front(boost::make_shared<IndexConditional>(0, 8, 12));
|
||||
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesNet);
|
||||
bayesNet.saveGraph("/tmp/symbolicBayesNet.dot");
|
||||
}
|
||||
|
||||
// create a BayesTree out of a Bayes net
|
||||
SymbolicBayesTree bayesTree(bayesNet);
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesTree);
|
||||
bayesTree.saveGraph("/tmp/symbolicBayesTree.dot");
|
||||
}
|
||||
|
||||
SymbolicBayesTree::Clique::shared_ptr R = bayesTree.root();
|
||||
|
||||
{
|
||||
// check shortcut P(S9||R) to root
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[9];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
{
|
||||
// check shortcut P(S8||R) to root
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[8];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(12, 14));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
{
|
||||
// check shortcut P(S4||R) to root
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[4];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(10, 13, 14));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
{
|
||||
// check shortcut P(S0||R) to root
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[0];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(12, 14));
|
||||
expected.push_front(boost::make_shared<IndexConditional>(8, 12, 14));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* *
|
||||
Bayes tree for smoother with "natural" ordering:
|
||||
C1 5 6
|
||||
C2 4 : 5
|
||||
C3 3 : 4
|
||||
C4 2 : 3
|
||||
C5 1 : 2
|
||||
C6 0 : 1
|
||||
**************************************************************************** */
|
||||
|
||||
TEST_UNSAFE( SymbolicBayesTree, linear_smoother_shortcuts ) {
|
||||
// Create smoother with 7 nodes
|
||||
SymbolicFactorGraph smoother;
|
||||
smoother.push_factor(0);
|
||||
smoother.push_factor(0, 1);
|
||||
smoother.push_factor(1, 2);
|
||||
smoother.push_factor(2, 3);
|
||||
smoother.push_factor(3, 4);
|
||||
smoother.push_factor(4, 5);
|
||||
smoother.push_factor(5, 6);
|
||||
|
||||
BayesNet<IndexConditional> bayesNet =
|
||||
*SymbolicSequentialSolver(smoother).eliminate();
|
||||
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesNet);
|
||||
bayesNet.saveGraph("/tmp/symbolicBayesNet.dot");
|
||||
}
|
||||
|
||||
// create a BayesTree out of a Bayes net
|
||||
SymbolicBayesTree bayesTree(bayesNet);
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesTree);
|
||||
bayesTree.saveGraph("/tmp/symbolicBayesTree.dot");
|
||||
}
|
||||
|
||||
SymbolicBayesTree::Clique::shared_ptr R = bayesTree.root();
|
||||
|
||||
{
|
||||
// check shortcut P(S2||R) to root
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[4]; // 4 is frontal in C2
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
{
|
||||
// check shortcut P(S3||R) to root
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[3]; // 3 is frontal in C3
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(4, 5));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
{
|
||||
// check shortcut P(S4||R) to root
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bayesTree[2]; // 2 is frontal in C4
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(3, 5));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// from testSymbolicJunctionTree, which failed at one point
|
||||
TEST(SymbolicBayesTree, complicatedMarginal) {
|
||||
|
||||
// Create the conditionals to go in the BayesTree
|
||||
list<Index> L;
|
||||
L = list_of(1)(2)(5);
|
||||
IndexConditional::shared_ptr R_1_2(new IndexConditional(L, 2));
|
||||
L = list_of(3)(4)(6);
|
||||
IndexConditional::shared_ptr R_3_4(new IndexConditional(L, 2));
|
||||
L = list_of(5)(6)(7)(8);
|
||||
IndexConditional::shared_ptr R_5_6(new IndexConditional(L, 2));
|
||||
L = list_of(7)(8)(11);
|
||||
IndexConditional::shared_ptr R_7_8(new IndexConditional(L, 2));
|
||||
L = list_of(9)(10)(11)(12);
|
||||
IndexConditional::shared_ptr R_9_10(new IndexConditional(L, 2));
|
||||
L = list_of(11)(12);
|
||||
IndexConditional::shared_ptr R_11_12(new IndexConditional(L, 2));
|
||||
|
||||
// Symbolic Bayes Tree
|
||||
typedef SymbolicBayesTree::Clique Clique;
|
||||
typedef SymbolicBayesTree::sharedClique sharedClique;
|
||||
|
||||
// Create Bayes Tree
|
||||
SymbolicBayesTree bt;
|
||||
bt.insert(sharedClique(new Clique(R_11_12)));
|
||||
bt.insert(sharedClique(new Clique(R_9_10)));
|
||||
bt.insert(sharedClique(new Clique(R_7_8)));
|
||||
bt.insert(sharedClique(new Clique(R_5_6)));
|
||||
bt.insert(sharedClique(new Clique(R_3_4)));
|
||||
bt.insert(sharedClique(new Clique(R_1_2)));
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bt);
|
||||
bt.saveGraph("/tmp/symbolicBayesTree.dot");
|
||||
}
|
||||
|
||||
SymbolicBayesTree::Clique::shared_ptr R = bt.root();
|
||||
SymbolicBayesNet empty;
|
||||
|
||||
// Shortcut on 9
|
||||
{
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bt[9];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
EXPECT(assert_equal(empty, shortcut));
|
||||
}
|
||||
|
||||
// Shortcut on 7
|
||||
{
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bt[7];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
EXPECT(assert_equal(empty, shortcut));
|
||||
}
|
||||
|
||||
// Shortcut on 5
|
||||
{
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bt[5];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(8, 11));
|
||||
expected.push_front(boost::make_shared<IndexConditional>(7, 8, 11));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
// Shortcut on 3
|
||||
{
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bt[3];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(6, 11));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
// Shortcut on 1
|
||||
{
|
||||
SymbolicBayesTree::Clique::shared_ptr c = bt[1];
|
||||
SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic);
|
||||
SymbolicBayesNet expected;
|
||||
expected.push_front(boost::make_shared<IndexConditional>(5, 11));
|
||||
EXPECT(assert_equal(expected, shortcut));
|
||||
}
|
||||
|
||||
// Marginal on 5
|
||||
{
|
||||
IndexFactor::shared_ptr actual = bt.marginalFactor(5, EliminateSymbolic);
|
||||
EXPECT(assert_equal(IndexFactor(5), *actual, 1e-1));
|
||||
}
|
||||
|
||||
// Shortcut on 6
|
||||
{
|
||||
IndexFactor::shared_ptr actual = bt.marginalFactor(6, EliminateSymbolic);
|
||||
EXPECT(assert_equal(IndexFactor(6), *actual, 1e-1));
|
||||
}
|
||||
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
Loading…
Reference in New Issue