From 835d1d6b50d2428ab30b81b2cfe73dee53ea07e5 Mon Sep 17 00:00:00 2001 From: Abhijit Kundu Date: Thu, 21 Jun 2012 22:32:28 +0000 Subject: [PATCH] First Iteration of Shortcut Cache changes and misc const fixes --- gtsam/inference/BayesTree-inl.h | 21 ++- gtsam/inference/BayesTree.h | 6 + gtsam/inference/BayesTreeCliqueBase-inl.h | 178 ++++++++++++---------- gtsam/inference/BayesTreeCliqueBase.h | 13 +- 4 files changed, 128 insertions(+), 90 deletions(-) diff --git a/gtsam/inference/BayesTree-inl.h b/gtsam/inference/BayesTree-inl.h index 9fb8080ce..dfefd721e 100644 --- a/gtsam/inference/BayesTree-inl.h +++ b/gtsam/inference/BayesTree-inl.h @@ -557,10 +557,27 @@ namespace gtsam { /* ************************************************************************* */ template - template - void BayesTree::removeTop(const CONTAINER& keys, + void BayesTree::deleteCachedShorcuts(const sharedClique& subtree) { + // Check if subtree exists + if (subtree) { + //Delete CachedShortcut for this clique + subtree->resetCachedShortcut(); + // Recursive call over all child cliques + BOOST_FOREACH(sharedClique& childClique, subtree->children()) { + deleteCachedShorcuts(childClique); + } + } + } + + /* ************************************************************************* */ + template + template + void BayesTree::removeTop(const CONTAINER& keys, BayesNet& bn, typename BayesTree::Cliques& orphans) { + //TODO: Improve this + deleteCachedShorcuts(this->root_); + // process each key of the new factor BOOST_FOREACH(const Index& key, keys) { diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index de9a84640..14b01ab03 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -280,6 +280,12 @@ namespace gtsam { sharedClique insert(const sharedConditional& clique, std::list& children, bool isRootClique = false); + /** + * This deletes the cached shortcuts of all cliques in a subtree. This is + * performed when the bayes tree is modified. + */ + void deleteCachedShorcuts(const sharedClique& subtree); + private: /** deep copy to another tree */ diff --git a/gtsam/inference/BayesTreeCliqueBase-inl.h b/gtsam/inference/BayesTreeCliqueBase-inl.h index c57d6ed50..272747126 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inl.h +++ b/gtsam/inference/BayesTreeCliqueBase-inl.h @@ -102,103 +102,113 @@ namespace gtsam { return changed; } - /* ************************************************************************* */ - // The shortcut density is a conditional P(S|R) of the separator of this - // clique on the root. We can compute it recursively from the parent shortcut - // P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p - /* ************************************************************************* */ - template - BayesNet BayesTreeCliqueBase::shortcut(derived_ptr R, Eliminate function) { + /* ************************************************************************* */ + // The shortcut density is a conditional P(S|R) of the separator of this + // clique on the root. We can compute it recursively from the parent shortcut + // P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p + /* ************************************************************************* */ + template + BayesNet BayesTreeCliqueBase::shortcut( + derived_ptr R, Eliminate function) const{ - static const bool debug = false; + static const bool debug = false; - // A first base case is when this clique or its parent is the root, - // in which case we return an empty Bayes net. + BayesNet p_S_R; //shortcut P(S|R) - derived_ptr parent(parent_.lock()); + //Check if the ShortCut already exists + if(!cachedShortcut_){ - if (R.get()==this || parent==R) { - BayesNet empty; - return empty; - } + // A first base case is when this clique or its parent is the root, + // in which case we return an empty Bayes net. - // The root conditional - FactorGraph p_R(BayesNet(R->conditional())); + derived_ptr parent(parent_.lock()); - // The parent clique has a ConditionalType for each frontal node in Fp - // so we can obtain P(Fp|Sp) in factor graph form - FactorGraph p_Fp_Sp(BayesNet(parent->conditional())); + if (R.get() == this || parent == R) { + BayesNet empty; + return empty; + } - // If not the base case, obtain the parent shortcut P(Sp|R) as factors - FactorGraph p_Sp_R(parent->shortcut(R, function)); + // The root conditional + FactorGraph p_R(BayesNet(R->conditional())); - // now combine P(Cp|R) = P(Fp|Sp) * P(Sp|R) - FactorGraph p_Cp_R; - p_Cp_R.push_back(p_R); - p_Cp_R.push_back(p_Fp_Sp); - p_Cp_R.push_back(p_Sp_R); + // The parent clique has a ConditionalType for each frontal node in Fp + // so we can obtain P(Fp|Sp) in factor graph form + FactorGraph p_Fp_Sp(BayesNet(parent->conditional())); - // Eliminate into a Bayes net with ordering designed to integrate out - // any variables not in *our* separator. Variables to integrate out must be - // eliminated first hence the desired ordering is [Cp\S S]. - // However, an added wrinkle is that Cp might overlap with the root. - // Keys corresponding to the root should not be added to the ordering at all. + // If not the base case, obtain the parent shortcut P(Sp|R) as factors + FactorGraph p_Sp_R(parent->shortcut(R, function)); - if(debug) { - p_R.print("p_R: "); - p_Fp_Sp.print("p_Fp_Sp: "); - p_Sp_R.print("p_Sp_R: "); - } + // now combine P(Cp|R) = P(Fp|Sp) * P(Sp|R) + FactorGraph p_Cp_R; + p_Cp_R.push_back(p_R); + p_Cp_R.push_back(p_Fp_Sp); + p_Cp_R.push_back(p_Sp_R); - // We want to factor into a conditional of the clique variables given the - // root and the marginal on the root, integrating out all other variables. - // The integrands include any parents of this clique and the variables of - // the parent clique. - FastSet variablesAtBack; - FastSet separator; - size_t uniqueRootVariables = 0; - BOOST_FOREACH(const Index separatorIndex, this->conditional()->parents()) { - variablesAtBack.insert(separatorIndex); - separator.insert(separatorIndex); - if(debug) std::cout << "At back (this): " << separatorIndex << std::endl; - } - BOOST_FOREACH(const Index key, R->conditional()->keys()) { - if(variablesAtBack.insert(key).second) - ++ uniqueRootVariables; - if(debug) std::cout << "At back (root): " << key << std::endl; - } + // Eliminate into a Bayes net with ordering designed to integrate out + // any variables not in *our* separator. Variables to integrate out must be + // eliminated first hence the desired ordering is [Cp\S S]. + // However, an added wrinkle is that Cp might overlap with the root. + // Keys corresponding to the root should not be added to the ordering at all. - Permutation toBack = Permutation::PushToBack( - std::vector(variablesAtBack.begin(), variablesAtBack.end()), - R->conditional()->lastFrontalKey() + 1); - Permutation::shared_ptr toBackInverse(toBack.inverse()); - BOOST_FOREACH(const typename FactorType::shared_ptr& factor, p_Cp_R) { - factor->permuteWithInverse(*toBackInverse); } - typename BayesNet::shared_ptr eliminated(EliminationTree< - FactorType>::Create(p_Cp_R)->eliminate(function)); + if(debug) { + p_R.print("p_R: "); + p_Fp_Sp.print("p_Fp_Sp: "); + p_Sp_R.print("p_Sp_R: "); + } - // Take only the conditionals for p(S|R). We check for each variable being - // in the separator set because if some separator variables overlap with - // root variables, we cannot rely on the number of root variables, and also - // want to include those variables in the conditional. - BayesNet p_S_R; - BOOST_REVERSE_FOREACH(typename ConditionalType::shared_ptr conditional, *eliminated) { - assert(conditional->nrFrontals() == 1); - if(separator.find(toBack[conditional->firstFrontalKey()]) != separator.end()) { - if(debug) - conditional->print("Taking C|R conditional: "); - p_S_R.push_front(conditional); - } - if(p_S_R.size() == separator.size()) - break; - } + // We want to factor into a conditional of the clique variables given the + // root and the marginal on the root, integrating out all other variables. + // The integrands include any parents of this clique and the variables of + // the parent clique. + FastSet variablesAtBack; + FastSet separator; + size_t uniqueRootVariables = 0; + BOOST_FOREACH(const Index separatorIndex, this->conditional()->parents()) { + variablesAtBack.insert(separatorIndex); + separator.insert(separatorIndex); + if(debug) std::cout << "At back (this): " << separatorIndex << std::endl; + } + BOOST_FOREACH(const Index key, R->conditional()->keys()) { + if(variablesAtBack.insert(key).second) + ++ uniqueRootVariables; + if(debug) std::cout << "At back (root): " << key << std::endl; + } - // Undo the permutation - if(debug) toBack.print("toBack: "); - p_S_R.permuteWithInverse(toBack); + Permutation toBack = Permutation::PushToBack( + std::vector(variablesAtBack.begin(), variablesAtBack.end()), + R->conditional()->lastFrontalKey() + 1); + Permutation::shared_ptr toBackInverse(toBack.inverse()); + BOOST_FOREACH(const typename FactorType::shared_ptr& factor, p_Cp_R) { + factor->permuteWithInverse(*toBackInverse); } + typename BayesNet::shared_ptr eliminated(EliminationTree< + FactorType>::Create(p_Cp_R)->eliminate(function)); - // return the parent shortcut P(Sp|R) - assertInvariants(); + // Take only the conditionals for p(S|R). We check for each variable being + // in the separator set because if some separator variables overlap with + // root variables, we cannot rely on the number of root variables, and also + // want to include those variables in the conditional. + BOOST_REVERSE_FOREACH(typename ConditionalType::shared_ptr conditional, *eliminated) { + assert(conditional->nrFrontals() == 1); + if(separator.find(toBack[conditional->firstFrontalKey()]) != separator.end()) { + if(debug) + conditional->print("Taking C|R conditional: "); + p_S_R.push_front(conditional); + } + if(p_S_R.size() == separator.size()) + break; + } + + // Undo the permutation + if(debug) toBack.print("toBack: "); + p_S_R.permuteWithInverse(toBack); + + assertInvariants(); + cachedShortcut_ = p_S_R; + } + else + p_S_R = *cachedShortcut_; + + // return the shortcut P(S|R) return p_S_R; } @@ -210,7 +220,7 @@ namespace gtsam { /* ************************************************************************* */ template FactorGraph::FactorType> BayesTreeCliqueBase::marginal( - derived_ptr R, Eliminate function) { + derived_ptr R, Eliminate function) const{ // If we are the root, just return this root // NOTE: immediately cast to a factor graph BayesNet bn(R->conditional()); @@ -231,7 +241,7 @@ namespace gtsam { /* ************************************************************************* */ template FactorGraph::FactorType> BayesTreeCliqueBase::joint( - derived_ptr C2, derived_ptr R, Eliminate function) { + derived_ptr C2, derived_ptr R, Eliminate function) const { // For now, assume neither is the root // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R) diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index e09454e86..fa0992fbd 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -80,6 +80,9 @@ namespace gtsam { derived_weak_ptr parent_; std::list children_; + /// This stores the Cached Shortcut value + mutable boost::optional > cachedShortcut_; + /// @name Testable /// @{ @@ -150,14 +153,13 @@ namespace gtsam { bool permuteSeparatorWithInverse(const Permutation& inversePermutation); /** return the conditional P(S|Root) on the separator given the root */ - // TODO: create a cached version - BayesNet shortcut(derived_ptr root, Eliminate function); + BayesNet shortcut(derived_ptr root, Eliminate function) const; /** return the marginal P(C) of the clique */ - FactorGraph marginal(derived_ptr root, Eliminate function); + FactorGraph marginal(derived_ptr root, Eliminate function) const; /** return the joint P(C1,C2), where C1==this. TODO: not a method? */ - FactorGraph joint(derived_ptr C2, derived_ptr root, Eliminate function); + FactorGraph joint(derived_ptr C2, derived_ptr root, Eliminate function) const; friend class BayesTree; @@ -166,6 +168,9 @@ namespace gtsam { ///TODO: comment void assertInvariants() const; + /// Reset the computed shortcut of this clique. Used by friend BayesTree + void resetCachedShortcut() { cachedShortcut_ = boost::none; } + private: /** Cliques cannot be copied except by the clone() method, which does not