diff --git a/.cproject b/.cproject index 766138423..52375afab 100644 --- a/.cproject +++ b/.cproject @@ -863,6 +863,14 @@ true true + + make + -j1 + testSymbolicBayesTree.run + true + false + true + make -j2 diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 90a294893..8ff50ad3c 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -36,40 +36,6 @@ class Clique: public BayesTreeCliqueBase { protected: - /// Calculate set S\B - vector separatorShortcutVariables(derived_ptr B) const { - sharedConditional p_F_S = this->conditional(); - vector &indicesB = B->conditional()->keys(); - vector S_setminus_B; - set_difference(p_F_S->beginParents(), p_F_S->endParents(), // - indicesB.begin(), indicesB.end(), back_inserter(S_setminus_B)); - return S_setminus_B; - } - - /** - * Determine variable indices to keep in recursive separator shortcut calculation - * The factor graph p_Cp_B has keys from the parent clique Cp and from B. - * But we only keep the variables not in S union B. - */ - vector indices(derived_ptr B, - const FactorGraph& p_Cp_B) const { - - // We do this by first merging S and B - sharedConditional p_F_S = this->conditional(); - vector &indicesB = B->conditional()->keys(); - vector S_union_B; - set_union(p_F_S->beginParents(), p_F_S->endParents(), // - indicesB.begin(), indicesB.end(), back_inserter(S_union_B)); - - // then intersecting S_union_B with all keys in p_Cp_B - set allKeys = p_Cp_B.keys(); - vector keepers; - set_intersection(S_union_B.begin(), S_union_B.end(), // - allKeys.begin(), allKeys.end(), back_inserter(keepers)); - - return keepers; - } - public: typedef BayesTreeCliqueBase Base; @@ -102,88 +68,6 @@ public: return result; } - /** - * Separator shortcut function P(S||B) = P(S\B|B) - * where S is a clique separator, and B any node (e.g., a brancing in the tree) - * We can compute it recursively from the parent shortcut - * P(Sp||B) as \int P(Fp|Sp) P(Sp||B), where Fp are the frontal nodes in p - */ - FactorGraph::shared_ptr separatorShortcut(derived_ptr B) const { - - typedef FactorGraph FG; - - FG::shared_ptr p_S_B; //shortcut P(S||B) This is empty now - - // We only calculate the shortcut when this clique is not B - // and when the S\B is not empty - vector S_setminus_B = separatorShortcutVariables(B); - if (B.get() != this && !S_setminus_B.empty()) { - - // Obtain P(Fp|Sp) as a factor - derived_ptr parent(parent_.lock()); - boost::shared_ptr p_Fp_Sp = parent->conditional()->toFactor(); - - // Obtain the parent shortcut P(Sp|B) as factors - // TODO: really annoying that we eliminate more than we have to ! - // TODO: we should only eliminate C_p\B, with S\B variables last - // TODO: and this index dance will be easier then, as well - FG p_Sp_B(parent->shortcut(B, &EliminateDiscrete)); - - // now combine P(Cp||B) = P(Fp|Sp) * P(Sp||B) - boost::shared_ptr p_Cp_B(new FG); - p_Cp_B->push_back(p_Fp_Sp); - p_Cp_B->push_back(p_Sp_B); - - // Figure out how many variables there are in in the shortcut -// size_t nVariables = *max_element(S_setminus_B.begin(),S_setminus_B.end()); -// cout << "nVariables: " << nVariables << endl; -// VariableIndex::shared_ptr structure(new VariableIndex(*p_Cp_B)); -// GTSAM_PRINT(*p_Cp_B); -// GTSAM_PRINT(*structure); - - // Create a generic solver that will marginalize for us - GenericSequentialSolver solver(*p_Cp_B); - - // The factor graph above will have keys from the parent clique Cp and from B. - // But we only keep the variables not in S union B. - vector keepers = indices(B, *p_Cp_B); - - p_S_B = solver.jointFactorGraph(keepers, &EliminateDiscrete); - } - // return the shortcut P(S||B) - return p_S_B; - } - - /** - * The shortcut density is a conditional P(S||B) of the separator of this - * clique on the clique B. - */ - BayesNet shortcut(derived_ptr B, - Eliminate function) const { - - //Check if the ShortCut already exists - if (cachedShortcut_) { - return *cachedShortcut_; // return the cached version - } else { - BayesNet bn; - FactorGraph::shared_ptr fg = separatorShortcut(B); - if (fg) { - // calculate set S\B of indices to keep in Bayes net - vector S_setminus_B = separatorShortcutVariables(B); - set keep(S_setminus_B.begin(), S_setminus_B.end()); - - BOOST_FOREACH (FactorType::shared_ptr factor,*fg) { - DecisionTreeFactor::shared_ptr df = boost::dynamic_pointer_cast< - DecisionTreeFactor>(factor); - if (keep.count(*factor->begin())) - bn.push_front(boost::make_shared(1, *df)); - } - } - cachedShortcut_ = bn; - return bn; - } - } - }; typedef BayesTree DiscreteBayesTree; @@ -196,7 +80,7 @@ double evaluate(const DiscreteBayesTree& tree, /* ************************************************************************* */ -TEST_UNSAFE( DiscreteMarginals, thinTree ) { +TEST_UNSAFE( DiscreteBayesTree, thinTree ) { const int nrNodes = 15; const size_t nrStates = 2; diff --git a/gtsam/inference/BayesTreeCliqueBase-inl.h b/gtsam/inference/BayesTreeCliqueBase-inl.h index 7ca40345b..5ac87b6c3 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inl.h +++ b/gtsam/inference/BayesTreeCliqueBase-inl.h @@ -22,7 +22,7 @@ namespace gtsam { /* ************************************************************************* */ template - void BayesTreeCliqueBase::assertInvariants() const { + void BayesTreeCliqueBase::assertInvariants() const { #ifndef NDEBUG // We rely on the keys being sorted // FastVector sortedUniqueKeys(conditional_->begin(), conditional_->end()); @@ -35,27 +35,71 @@ namespace gtsam { /* ************************************************************************* */ template - BayesTreeCliqueBase::BayesTreeCliqueBase(const sharedConditional& conditional) : - conditional_(conditional) { + std::vector BayesTreeCliqueBase::separator_setminus_B( + derived_ptr B) const { + sharedConditional p_F_S = this->conditional(); + std::vector &indicesB = B->conditional()->keys(); + std::vector S_setminus_B; + std::set_difference(p_F_S->beginParents(), p_F_S->endParents(), // + indicesB.begin(), indicesB.end(), back_inserter(S_setminus_B)); + return S_setminus_B; + } + + /* ************************************************************************* */ + template + std::vector BayesTreeCliqueBase::shortcut_indices( + derived_ptr B, const FactorGraph& p_Cp_B) const { + std::set allKeys = p_Cp_B.keys(); + std::vector &indicesB = B->conditional()->keys(); + std::vector keep; +#ifdef OLD_INDICES + // We do this by first merging S and B + sharedConditional p_F_S = this->conditional(); + std::vector S_union_B; + std::set_union(p_F_S->beginParents(), p_F_S->endParents(),// + indicesB.begin(), indicesB.end(), back_inserter(S_union_B)); + + // then intersecting S_union_B with all keys in p_Cp_B + std::set_intersection(S_union_B.begin(), S_union_B.end(),// + allKeys.begin(), allKeys.end(), back_inserter(keep)); +#else + std::vector S_setminus_B = separator_setminus_B(B); // TODO, get as argument? + std::set_intersection(S_setminus_B.begin(), S_setminus_B.end(), // + allKeys.begin(), allKeys.end(), back_inserter(keep)); + std::set_intersection(indicesB.begin(), indicesB.end(), // + allKeys.begin(), allKeys.end(), back_inserter(keep)); +#endif + // BOOST_FOREACH(Index j, keep) std::cout << j << " "; std::cout << std::endl; + return keep; + } + + /* ************************************************************************* */ + template + BayesTreeCliqueBase::BayesTreeCliqueBase( + const sharedConditional& conditional) : + conditional_(conditional) { assertInvariants(); } /* ************************************************************************* */ template - BayesTreeCliqueBase::BayesTreeCliqueBase(const std::pair >& result) : - conditional_(result.first) { + BayesTreeCliqueBase::BayesTreeCliqueBase( + const std::pair >& result) : + conditional_(result.first) { assertInvariants(); } /* ************************************************************************* */ template - void BayesTreeCliqueBase::print(const std::string& s, const IndexFormatter& indexFormatter) const { + void BayesTreeCliqueBase::print(const std::string& s, + const IndexFormatter& indexFormatter) const { conditional_->print(s, indexFormatter); } /* ************************************************************************* */ template - size_t BayesTreeCliqueBase::treeSize() const { + size_t BayesTreeCliqueBase::treeSize() const { size_t size = 1; BOOST_FOREACH(const derived_ptr& child, children_) size += child->treeSize(); @@ -64,15 +108,17 @@ namespace gtsam { /* ************************************************************************* */ template - void BayesTreeCliqueBase::printTree(const std::string& indent, const IndexFormatter& indexFormatter) const { + void BayesTreeCliqueBase::printTree( + const std::string& indent, const IndexFormatter& indexFormatter) const { asDerived(this)->print(indent, indexFormatter); BOOST_FOREACH(const derived_ptr& child, children_) - child->printTree(indent+" ", indexFormatter); + child->printTree(indent + " ", indexFormatter); } /* ************************************************************************* */ template - void BayesTreeCliqueBase::permuteWithInverse(const Permutation& inversePermutation) { + void BayesTreeCliqueBase::permuteWithInverse( + const Permutation& inversePermutation) { conditional_->permuteWithInverse(inversePermutation); BOOST_FOREACH(const derived_ptr& child, children_) { child->permuteWithInverse(inversePermutation); @@ -82,19 +128,21 @@ namespace gtsam { /* ************************************************************************* */ template - bool BayesTreeCliqueBase::permuteSeparatorWithInverse(const Permutation& inversePermutation) { - bool changed = conditional_->permuteSeparatorWithInverse(inversePermutation); + bool BayesTreeCliqueBase::permuteSeparatorWithInverse( + const Permutation& inversePermutation) { + bool changed = conditional_->permuteSeparatorWithInverse( + inversePermutation); #ifndef NDEBUG if(!changed) { - BOOST_FOREACH(Index& separatorKey, conditional_->parents()) { assert(separatorKey == inversePermutation[separatorKey]); } + BOOST_FOREACH(Index& separatorKey, conditional_->parents()) {assert(separatorKey == inversePermutation[separatorKey]);} BOOST_FOREACH(const derived_ptr& child, children_) { assert(child->permuteSeparatorWithInverse(inversePermutation) == false); } } #endif - if(changed) { + if (changed) { BOOST_FOREACH(const derived_ptr& child, children_) { - (void)child->permuteSeparatorWithInverse(inversePermutation); + (void) child->permuteSeparatorWithInverse(inversePermutation); } } assertInvariants(); @@ -108,105 +156,47 @@ namespace gtsam { /* ************************************************************************* */ template BayesNet BayesTreeCliqueBase::shortcut( - derived_ptr R, Eliminate function) const{ + derived_ptr B, Eliminate function) const { - static const bool debug = false; + // Check if the ShortCut already exists + if (!cachedShortcut_) { - BayesNet p_S_R; //shortcut P(S|R) This is empty now + // We only calculate the shortcut when this clique is not B + // and when the S\B is not empty + std::vector S_setminus_B = separator_setminus_B(B); + if (B.get() != this && !S_setminus_B.empty()) { - //Check if the ShortCut already exists - if(!cachedShortcut_){ + // Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph + derived_ptr parent(parent_.lock()); + FactorGraph p_Cp_B(parent->shortcut(B, function)); // P(Sp||B) + p_Cp_B.push_back(parent->conditional()->toFactor()); // P(Fp|Sp) - // A first base case is when this clique or its parent is the root, - // in which case we return an empty Bayes net. + // Add the root conditional + // TODO: this is needed because otherwise we will be solving singular + // systems and exceptions are thrown. However, we should be able to omit + // this if we can get ATTEMPT_AT_NOT_ELIMINATING_ALL in + // GenericSequentialSolver.* working... + p_Cp_B.push_back(B->conditional()->toFactor()); // P(B) - derived_ptr parent(parent_.lock()); - if (R.get() != this && parent != R) { + // Create solver that will marginalize for us + GenericSequentialSolver solver(p_Cp_B); - // The root conditional - FactorGraph p_R(BayesNet(R->conditional())); + // Determine the variables we want to keep + std::vector keep = shortcut_indices(B, p_Cp_B); - // The parent clique has a ConditionalType for each frontal node in Fp - // so we can obtain P(Fp|Sp) in factor graph form - FactorGraph p_Fp_Sp(BayesNet(parent->conditional())); + // Finally, we only want to have S\B variables in the Bayes net, so + size_t nrFrontals = S_setminus_B.size(); + cachedShortcut_ = // + *solver.conditionalBayesNet(keep, nrFrontals, function); + assertInvariants(); + } else { + BayesNet empty; + cachedShortcut_ = empty; + } + } - // If not the base case, obtain the parent shortcut P(Sp|R) as factors - FactorGraph p_Sp_R(parent->shortcut(R, function)); - - // now combine P(Cp|R) = P(Fp|Sp) * P(Sp|R) - FactorGraph p_Cp_R; - p_Cp_R.push_back(p_R); - p_Cp_R.push_back(p_Fp_Sp); - p_Cp_R.push_back(p_Sp_R); - - // Eliminate into a Bayes net with ordering designed to integrate out - // any variables not in *our* separator. Variables to integrate out must be - // eliminated first hence the desired ordering is [Cp\S S]. - // However, an added wrinkle is that Cp might overlap with the root. - // Keys corresponding to the root should not be added to the ordering at all. - - if(debug) { - p_R.print("p_R: "); - p_Fp_Sp.print("p_Fp_Sp: "); - p_Sp_R.print("p_Sp_R: "); - } - - // We want to factor into a conditional of the clique variables given the - // root and the marginal on the root, integrating out all other variables. - // The integrands include any parents of this clique and the variables of - // the parent clique. - FastSet variablesAtBack; - FastSet separator; - size_t uniqueRootVariables = 0; - BOOST_FOREACH(const Index separatorIndex, this->conditional()->parents()) { - variablesAtBack.insert(separatorIndex); - separator.insert(separatorIndex); - if(debug) std::cout << "At back (this): " << separatorIndex << std::endl; - } - BOOST_FOREACH(const Index key, R->conditional()->keys()) { - if(variablesAtBack.insert(key).second) - ++ uniqueRootVariables; - if(debug) std::cout << "At back (root): " << key << std::endl; - } - - Permutation toBack = Permutation::PushToBack( - std::vector(variablesAtBack.begin(), variablesAtBack.end()), - R->conditional()->lastFrontalKey() + 1); - Permutation::shared_ptr toBackInverse(toBack.inverse()); - BOOST_FOREACH(const typename FactorType::shared_ptr& factor, p_Cp_R) { - factor->permuteWithInverse(*toBackInverse); } - typename BayesNet::shared_ptr eliminated(EliminationTree< - FactorType>::Create(p_Cp_R)->eliminate(function)); - - // Take only the conditionals for p(S|R). We check for each variable being - // in the separator set because if some separator variables overlap with - // root variables, we cannot rely on the number of root variables, and also - // want to include those variables in the conditional. - BOOST_REVERSE_FOREACH(typename ConditionalType::shared_ptr conditional, *eliminated) { - assert(conditional->nrFrontals() == 1); - if(separator.find(toBack[conditional->firstFrontalKey()]) != separator.end()) { - if(debug) - conditional->print("Taking C|R conditional: "); - p_S_R.push_front(conditional); - } - if(p_S_R.size() == separator.size()) - break; - } - - // Undo the permutation - if(debug) toBack.print("toBack: "); - p_S_R.permuteWithInverse(toBack); - } - - cachedShortcut_ = p_S_R; - } - else - p_S_R = *cachedShortcut_; // return the cached version - - assertInvariants(); - - // return the shortcut P(S|R) - return p_S_R; + // return the shortcut P(S||B) + return *cachedShortcut_; // return the cached version } /* ************************************************************************* */ @@ -216,12 +206,13 @@ namespace gtsam { // Because the root clique could be very big. /* ************************************************************************* */ template - FactorGraph::FactorType> BayesTreeCliqueBase::marginal( - derived_ptr R, Eliminate function) const{ + FactorGraph::FactorType> BayesTreeCliqueBase< + DERIVED, CONDITIONAL>::marginal(derived_ptr R, Eliminate function) const { // If we are the root, just return this root // NOTE: immediately cast to a factor graph BayesNet bn(R->conditional()); - if (R.get()==this) return bn; + if (R.get() == this) + return bn; // Combine P(F|S), P(S|R), and P(R) BayesNet p_FSR = this->shortcut(R, function); @@ -237,16 +228,21 @@ namespace gtsam { // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R) /* ************************************************************************* */ template - FactorGraph::FactorType> BayesTreeCliqueBase::joint( - derived_ptr C2, derived_ptr R, Eliminate function) const { + FactorGraph::FactorType> BayesTreeCliqueBase< + DERIVED, CONDITIONAL>::joint(derived_ptr C2, derived_ptr R, + Eliminate function) const { // For now, assume neither is the root // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R) FactorGraph joint; - if (!isRoot()) joint.push_back(this->conditional()->toFactor()); // P(F1|S1) - if (!isRoot()) joint.push_back(shortcut(R, function)); // P(S1|R) - if (!C2->isRoot()) joint.push_back(C2->conditional()->toFactor()); // P(F2|S2) - if (!C2->isRoot()) joint.push_back(C2->shortcut(R, function)); // P(S2|R) + if (!isRoot()) + joint.push_back(this->conditional()->toFactor()); // P(F1|S1) + if (!isRoot()) + joint.push_back(shortcut(R, function)); // P(S1|R) + if (!C2->isRoot()) + joint.push_back(C2->conditional()->toFactor()); // P(F2|S2) + if (!C2->isRoot()) + joint.push_back(C2->shortcut(R, function)); // P(S2|R) joint.push_back(R->conditional()->toFactor()); // P(R) // Find the keys of both C1 and C2 @@ -257,29 +253,30 @@ namespace gtsam { keys12.insert(keys2.begin(), keys2.end()); // Calculate the marginal - std::vector keys12vector; keys12vector.reserve(keys12.size()); + std::vector keys12vector; + keys12vector.reserve(keys12.size()); keys12vector.insert(keys12vector.begin(), keys12.begin(), keys12.end()); assertInvariants(); GenericSequentialSolver solver(joint); return *solver.jointFactorGraph(keys12vector, function); } - /* ************************************************************************* */ - template - void BayesTreeCliqueBase::deleteCachedShorcuts() { + /* ************************************************************************* */ + template + void BayesTreeCliqueBase::deleteCachedShorcuts() { - // When a shortcut is requested, all of the shortcuts between it and the - // root are also generated. So, if this clique's cached shortcut is set, - // recursively call over all child cliques. Otherwise, it is unnecessary. - if(cachedShortcut_) { - BOOST_FOREACH(derived_ptr& child, children_) { - child->deleteCachedShorcuts(); - } + // When a shortcut is requested, all of the shortcuts between it and the + // root are also generated. So, if this clique's cached shortcut is set, + // recursively call over all child cliques. Otherwise, it is unnecessary. + if (cachedShortcut_) { + BOOST_FOREACH(derived_ptr& child, children_) { + child->deleteCachedShorcuts(); + } - //Delete CachedShortcut for this clique - this->resetCachedShortcut(); - } + //Delete CachedShortcut for this clique + this->resetCachedShortcut(); + } - } + } } diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index a2fb8feef..ceeef3f45 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -25,7 +25,9 @@ #include #include -namespace gtsam { template class BayesTree; } +namespace gtsam { + template class BayesTree; +} namespace gtsam { @@ -48,7 +50,7 @@ namespace gtsam { struct BayesTreeCliqueBase { public: - typedef BayesTreeCliqueBase This; + typedef BayesTreeCliqueBase This; typedef DERIVED DerivedType; typedef CONDITIONAL ConditionalType; typedef boost::shared_ptr sharedConditional; @@ -61,19 +63,22 @@ namespace gtsam { protected: - /// @name Standard Constructors - /// @{ + /// @name Standard Constructors + /// @{ /** Default constructor */ - BayesTreeCliqueBase() {} + BayesTreeCliqueBase() { + } /** Construct from a conditional, leaving parent and child pointers uninitialized */ BayesTreeCliqueBase(const sharedConditional& conditional); /** Construct from an elimination result, which is a pair */ - BayesTreeCliqueBase(const std::pair >& result); + BayesTreeCliqueBase( + const std::pair >& result); - /// @} + /// @} /// This stores the Cached Shortcut value mutable boost::optional > cachedShortcut_; @@ -83,67 +88,91 @@ namespace gtsam { derived_weak_ptr parent_; std::list children_; - /// @name Testable - /// @{ + /// @name Testable + /// @{ /** check equality */ - bool equals(const This& other, double tol=1e-9) const { - return (!conditional_ && !other.conditional()) || - conditional_->equals(*other.conditional(), tol); + bool equals(const This& other, double tol = 1e-9) const { + return (!conditional_ && !other.conditional()) + || conditional_->equals(*other.conditional(), tol); } /** print this node */ - void print(const std::string& s = "", const IndexFormatter& indexFormatter = DefaultIndexFormatter ) const; + void print(const std::string& s = "", const IndexFormatter& indexFormatter = + DefaultIndexFormatter) const; /** print this node and entire subtree below it */ - void printTree(const std::string& indent="", const IndexFormatter& indexFormatter = DefaultIndexFormatter ) const; + void printTree(const std::string& indent = "", + const IndexFormatter& indexFormatter = DefaultIndexFormatter) const; - /// @} - /// @name Standard Interface - /// @{ + /// @} + /// @name Standard Interface + /// @{ /** Access the conditional */ - const sharedConditional& conditional() const { return conditional_; } + const sharedConditional& conditional() const { + return conditional_; + } /** is this the root of a Bayes tree ? */ - inline bool isRoot() const { return parent_.expired(); } + inline bool isRoot() const { + return parent_.expired(); + } /** The size of subtree rooted at this clique, i.e., nr of Cliques */ size_t treeSize() const; /** The arrow operator accesses the conditional */ - const ConditionalType* operator->() const { return conditional_.get(); } + const ConditionalType* operator->() const { + return conditional_.get(); + } /** return the const reference of children */ - const std::list& children() const { return children_; } + const std::list& children() const { + return children_; + } /** return a shared_ptr to the parent clique */ - derived_ptr parent() const { return parent_.lock(); } + derived_ptr parent() const { + return parent_.lock(); + } - /// @} - /// @name Advanced Interface - /// @{ + /// @} + /// @name Advanced Interface + /// @{ /** The arrow operator accesses the conditional */ - ConditionalType* operator->() { return conditional_.get(); } + ConditionalType* operator->() { + return conditional_.get(); + } /** return the reference of children non-const version*/ - std::list& children() { return children_; } + std::list& children() { + return children_; + } /** Construct shared_ptr from a conditional, leaving parent and child pointers uninitialized */ - static derived_ptr Create(const sharedConditional& conditional) { return boost::make_shared(conditional); } + static derived_ptr Create(const sharedConditional& conditional) { + return boost::make_shared(conditional); + } /** Construct shared_ptr from a FactorGraph::EliminationResult. In this class * the conditional part is kept and the factor part is ignored, but in derived clique * types, such as ISAM2Clique, the factor part is kept as a cached factor. * @param result An elimination result, which is a pair */ - static derived_ptr Create(const std::pair >& result) { return boost::make_shared(result); } + static derived_ptr Create( + const std::pair >& result) { + return boost::make_shared(result); + } - /** Returns a new clique containing a copy of the conditional but without - * the parent and child clique pointers. - */ - derived_ptr clone() const { return Create(sharedConditional(new ConditionalType(*conditional_))); } + /** Returns a new clique containing a copy of the conditional but without + * the parent and child clique pointers. + */ + derived_ptr clone() const { + return Create(sharedConditional(new ConditionalType(*conditional_))); + } /** Permute the variables in the whole subtree rooted at this clique */ void permuteWithInverse(const Permutation& inversePermutation); @@ -156,13 +185,16 @@ namespace gtsam { bool permuteSeparatorWithInverse(const Permutation& inversePermutation); /** return the conditional P(S|Root) on the separator given the root */ - BayesNet shortcut(derived_ptr root, Eliminate function) const; + BayesNet shortcut(derived_ptr root, + Eliminate function) const; /** return the marginal P(C) of the clique */ - FactorGraph marginal(derived_ptr root, Eliminate function) const; + FactorGraph marginal(derived_ptr root, + Eliminate function) const; /** return the joint P(C1,C2), where C1==this. TODO: not a method? */ - FactorGraph joint(derived_ptr C2, derived_ptr root, Eliminate function) const; + FactorGraph joint(derived_ptr C2, derived_ptr root, + Eliminate function) const; /** * This deletes the cached shortcuts of all cliques (subtree) below this clique. @@ -171,29 +203,53 @@ namespace gtsam { void deleteCachedShorcuts(); /** return cached shortcut of the clique */ - const boost::optional > cachedShortcut() const { return cachedShortcut_; } + const boost::optional > cachedShortcut() const { + return cachedShortcut_; + } - friend class BayesTree; + friend class BayesTree ; protected: - ///TODO: comment + /// assert invariants that have to hold in a clique void assertInvariants() const; + /// Calculate set \f$ S \setminus B \f$ for shortcut calculations + std::vector separator_setminus_B(derived_ptr B) const; + + /// Calculate set \f$ S_p \cap B \f$ for shortcut calculations + std::vector parent_separator_intersection_B(derived_ptr B) const; + + /** + * Determine variable indices to keep in recursive separator shortcut calculation + * The factor graph p_Cp_B has keys from the parent clique Cp and from B. + * But we only keep the variables not in S union B. + */ + std::vector shortcut_indices(derived_ptr B, + const FactorGraph& p_Cp_B) const; + /// Reset the computed shortcut of this clique. Used by friend BayesTree - void resetCachedShortcut() { cachedShortcut_ = boost::none; } + void resetCachedShortcut() { + cachedShortcut_ = boost::none; + } private: - /** Cliques cannot be copied except by the clone() method, which does not + /** + * Cliques cannot be copied except by the clone() method, which does not * copy the parent and child pointers. */ - BayesTreeCliqueBase(const This& other) { assert(false); } + BayesTreeCliqueBase(const This& other) { + assert(false); + } /** Cliques cannot be copied except by the clone() method, which does not * copy the parent and child pointers. */ - This& operator=(const This& other) { assert(false); return *this; } + This& operator=(const This& other) { + assert(false); + return *this; + } /** Serialization function */ friend class boost::serialization::access; @@ -204,17 +260,19 @@ namespace gtsam { ar & BOOST_SERIALIZATION_NVP(children_); } - /// @} + /// @} - }; // \struct Clique + }; + // \struct Clique template - const DERIVED* asDerived(const BayesTreeCliqueBase* base) { + const DERIVED* asDerived( + const BayesTreeCliqueBase* base) { return static_cast(base); } template - DERIVED* asDerived(BayesTreeCliqueBase* base) { + DERIVED* asDerived(BayesTreeCliqueBase* base) { return static_cast(base); } diff --git a/gtsam/inference/tests/testSymbolicBayesTree.cpp b/gtsam/inference/tests/testSymbolicBayesTree.cpp new file mode 100644 index 000000000..5e00e23b1 --- /dev/null +++ b/gtsam/inference/tests/testSymbolicBayesTree.cpp @@ -0,0 +1,278 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file testSymbolicBayesTree.cpp + * @date sept 15, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include +using namespace boost::assign; + +#include + +using namespace std; +using namespace gtsam; + +static bool debug = false; + +typedef BayesNet SymbolicBayesNet; +typedef BayesTree SymbolicBayesTree; + +/* ************************************************************************* */ + +TEST_UNSAFE( SymbolicBayesTree, thinTree ) { + + // create a thin-tree Bayesnet, a la Jean-Guillaume + SymbolicBayesNet bayesNet; + bayesNet.push_front(boost::make_shared(14)); + + bayesNet.push_front(boost::make_shared(13, 14)); + bayesNet.push_front(boost::make_shared(12, 14)); + + bayesNet.push_front(boost::make_shared(11, 13, 14)); + bayesNet.push_front(boost::make_shared(10, 13, 14)); + bayesNet.push_front(boost::make_shared(9, 12, 14)); + bayesNet.push_front(boost::make_shared(8, 12, 14)); + + bayesNet.push_front(boost::make_shared(7, 11, 13)); + bayesNet.push_front(boost::make_shared(6, 11, 13)); + bayesNet.push_front(boost::make_shared(5, 10, 13)); + bayesNet.push_front(boost::make_shared(4, 10, 13)); + + bayesNet.push_front(boost::make_shared(3, 9, 12)); + bayesNet.push_front(boost::make_shared(2, 9, 12)); + bayesNet.push_front(boost::make_shared(1, 8, 12)); + bayesNet.push_front(boost::make_shared(0, 8, 12)); + + if (debug) { + GTSAM_PRINT(bayesNet); + bayesNet.saveGraph("/tmp/symbolicBayesNet.dot"); + } + + // create a BayesTree out of a Bayes net + SymbolicBayesTree bayesTree(bayesNet); + if (debug) { + GTSAM_PRINT(bayesTree); + bayesTree.saveGraph("/tmp/symbolicBayesTree.dot"); + } + + SymbolicBayesTree::Clique::shared_ptr R = bayesTree.root(); + + { + // check shortcut P(S9||R) to root + SymbolicBayesTree::Clique::shared_ptr c = bayesTree[9]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + EXPECT(assert_equal(expected, shortcut)); + } + + { + // check shortcut P(S8||R) to root + SymbolicBayesTree::Clique::shared_ptr c = bayesTree[8]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(12, 14)); + EXPECT(assert_equal(expected, shortcut)); + } + + { + // check shortcut P(S4||R) to root + SymbolicBayesTree::Clique::shared_ptr c = bayesTree[4]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(10, 13, 14)); + EXPECT(assert_equal(expected, shortcut)); + } + + { + // check shortcut P(S0||R) to root + SymbolicBayesTree::Clique::shared_ptr c = bayesTree[0]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(12, 14)); + expected.push_front(boost::make_shared(8, 12, 14)); + EXPECT(assert_equal(expected, shortcut)); + } +} + +/* ************************************************************************* * + Bayes tree for smoother with "natural" ordering: + C1 5 6 + C2 4 : 5 + C3 3 : 4 + C4 2 : 3 + C5 1 : 2 + C6 0 : 1 + **************************************************************************** */ + +TEST_UNSAFE( SymbolicBayesTree, linear_smoother_shortcuts ) { + // Create smoother with 7 nodes + SymbolicFactorGraph smoother; + smoother.push_factor(0); + smoother.push_factor(0, 1); + smoother.push_factor(1, 2); + smoother.push_factor(2, 3); + smoother.push_factor(3, 4); + smoother.push_factor(4, 5); + smoother.push_factor(5, 6); + + BayesNet bayesNet = + *SymbolicSequentialSolver(smoother).eliminate(); + + if (debug) { + GTSAM_PRINT(bayesNet); + bayesNet.saveGraph("/tmp/symbolicBayesNet.dot"); + } + + // create a BayesTree out of a Bayes net + SymbolicBayesTree bayesTree(bayesNet); + if (debug) { + GTSAM_PRINT(bayesTree); + bayesTree.saveGraph("/tmp/symbolicBayesTree.dot"); + } + + SymbolicBayesTree::Clique::shared_ptr R = bayesTree.root(); + + { + // check shortcut P(S2||R) to root + SymbolicBayesTree::Clique::shared_ptr c = bayesTree[4]; // 4 is frontal in C2 + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + EXPECT(assert_equal(expected, shortcut)); + } + + { + // check shortcut P(S3||R) to root + SymbolicBayesTree::Clique::shared_ptr c = bayesTree[3]; // 3 is frontal in C3 + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(4, 5)); + EXPECT(assert_equal(expected, shortcut)); + } + + { + // check shortcut P(S4||R) to root + SymbolicBayesTree::Clique::shared_ptr c = bayesTree[2]; // 2 is frontal in C4 + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(3, 5)); + EXPECT(assert_equal(expected, shortcut)); + } +} + +/* ************************************************************************* */ +// from testSymbolicJunctionTree, which failed at one point +TEST(SymbolicBayesTree, complicatedMarginal) { + + // Create the conditionals to go in the BayesTree + list L; + L = list_of(1)(2)(5); + IndexConditional::shared_ptr R_1_2(new IndexConditional(L, 2)); + L = list_of(3)(4)(6); + IndexConditional::shared_ptr R_3_4(new IndexConditional(L, 2)); + L = list_of(5)(6)(7)(8); + IndexConditional::shared_ptr R_5_6(new IndexConditional(L, 2)); + L = list_of(7)(8)(11); + IndexConditional::shared_ptr R_7_8(new IndexConditional(L, 2)); + L = list_of(9)(10)(11)(12); + IndexConditional::shared_ptr R_9_10(new IndexConditional(L, 2)); + L = list_of(11)(12); + IndexConditional::shared_ptr R_11_12(new IndexConditional(L, 2)); + + // Symbolic Bayes Tree + typedef SymbolicBayesTree::Clique Clique; + typedef SymbolicBayesTree::sharedClique sharedClique; + + // Create Bayes Tree + SymbolicBayesTree bt; + bt.insert(sharedClique(new Clique(R_11_12))); + bt.insert(sharedClique(new Clique(R_9_10))); + bt.insert(sharedClique(new Clique(R_7_8))); + bt.insert(sharedClique(new Clique(R_5_6))); + bt.insert(sharedClique(new Clique(R_3_4))); + bt.insert(sharedClique(new Clique(R_1_2))); + if (debug) { + GTSAM_PRINT(bt); + bt.saveGraph("/tmp/symbolicBayesTree.dot"); + } + + SymbolicBayesTree::Clique::shared_ptr R = bt.root(); + SymbolicBayesNet empty; + + // Shortcut on 9 + { + SymbolicBayesTree::Clique::shared_ptr c = bt[9]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + EXPECT(assert_equal(empty, shortcut)); + } + + // Shortcut on 7 + { + SymbolicBayesTree::Clique::shared_ptr c = bt[7]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + EXPECT(assert_equal(empty, shortcut)); + } + + // Shortcut on 5 + { + SymbolicBayesTree::Clique::shared_ptr c = bt[5]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(8, 11)); + expected.push_front(boost::make_shared(7, 8, 11)); + EXPECT(assert_equal(expected, shortcut)); + } + + // Shortcut on 3 + { + SymbolicBayesTree::Clique::shared_ptr c = bt[3]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(6, 11)); + EXPECT(assert_equal(expected, shortcut)); + } + + // Shortcut on 1 + { + SymbolicBayesTree::Clique::shared_ptr c = bt[1]; + SymbolicBayesNet shortcut = c->shortcut(R, EliminateSymbolic); + SymbolicBayesNet expected; + expected.push_front(boost::make_shared(5, 11)); + EXPECT(assert_equal(expected, shortcut)); + } + + // Marginal on 5 + { + IndexFactor::shared_ptr actual = bt.marginalFactor(5, EliminateSymbolic); + EXPECT(assert_equal(IndexFactor(5), *actual, 1e-1)); + } + + // Shortcut on 6 + { + IndexFactor::shared_ptr actual = bt.marginalFactor(6, EliminateSymbolic); + EXPECT(assert_equal(IndexFactor(6), *actual, 1e-1)); + } + +} +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ +