From bfc033d3f76c861e13be53a813956733053e2e1e Mon Sep 17 00:00:00 2001 From: Abhijit Kundu Date: Sun, 24 Jun 2012 04:13:49 +0000 Subject: [PATCH] Minor change to shortcut --- gtsam/inference/BayesTreeCliqueBase-inl.h | 178 +++++++++++----------- 1 file changed, 88 insertions(+), 90 deletions(-) diff --git a/gtsam/inference/BayesTreeCliqueBase-inl.h b/gtsam/inference/BayesTreeCliqueBase-inl.h index 272747126..c5851a5e4 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inl.h +++ b/gtsam/inference/BayesTreeCliqueBase-inl.h @@ -102,114 +102,112 @@ namespace gtsam { return changed; } - /* ************************************************************************* */ - // The shortcut density is a conditional P(S|R) of the separator of this - // clique on the root. We can compute it recursively from the parent shortcut - // P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p - /* ************************************************************************* */ - template - BayesNet BayesTreeCliqueBase::shortcut( - derived_ptr R, Eliminate function) const{ + /* ************************************************************************* */ + // The shortcut density is a conditional P(S|R) of the separator of this + // clique on the root. We can compute it recursively from the parent shortcut + // P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p + /* ************************************************************************* */ + template + BayesNet BayesTreeCliqueBase::shortcut( + derived_ptr R, Eliminate function) const{ - static const bool debug = false; + static const bool debug = false; - BayesNet p_S_R; //shortcut P(S|R) + BayesNet p_S_R; //shortcut P(S|R) This is empty now - //Check if the ShortCut already exists - if(!cachedShortcut_){ + //Check if the ShortCut already exists + if(!cachedShortcut_){ - // A first base case is when this clique or its parent is the root, - // in which case we return an empty Bayes net. + // A first base case is when this clique or its parent is the root, + // in which case we return an empty Bayes net. - derived_ptr parent(parent_.lock()); + derived_ptr parent(parent_.lock()); + if (R.get() != this && parent != R) { - if (R.get() == this || parent == R) { - BayesNet empty; - return empty; - } + // The root conditional + FactorGraph p_R(BayesNet(R->conditional())); - // The root conditional - FactorGraph p_R(BayesNet(R->conditional())); + // The parent clique has a ConditionalType for each frontal node in Fp + // so we can obtain P(Fp|Sp) in factor graph form + FactorGraph p_Fp_Sp(BayesNet(parent->conditional())); - // The parent clique has a ConditionalType for each frontal node in Fp - // so we can obtain P(Fp|Sp) in factor graph form - FactorGraph p_Fp_Sp(BayesNet(parent->conditional())); + // If not the base case, obtain the parent shortcut P(Sp|R) as factors + FactorGraph p_Sp_R(parent->shortcut(R, function)); - // If not the base case, obtain the parent shortcut P(Sp|R) as factors - FactorGraph p_Sp_R(parent->shortcut(R, function)); + // now combine P(Cp|R) = P(Fp|Sp) * P(Sp|R) + FactorGraph p_Cp_R; + p_Cp_R.push_back(p_R); + p_Cp_R.push_back(p_Fp_Sp); + p_Cp_R.push_back(p_Sp_R); - // now combine P(Cp|R) = P(Fp|Sp) * P(Sp|R) - FactorGraph p_Cp_R; - p_Cp_R.push_back(p_R); - p_Cp_R.push_back(p_Fp_Sp); - p_Cp_R.push_back(p_Sp_R); + // Eliminate into a Bayes net with ordering designed to integrate out + // any variables not in *our* separator. Variables to integrate out must be + // eliminated first hence the desired ordering is [Cp\S S]. + // However, an added wrinkle is that Cp might overlap with the root. + // Keys corresponding to the root should not be added to the ordering at all. - // Eliminate into a Bayes net with ordering designed to integrate out - // any variables not in *our* separator. Variables to integrate out must be - // eliminated first hence the desired ordering is [Cp\S S]. - // However, an added wrinkle is that Cp might overlap with the root. - // Keys corresponding to the root should not be added to the ordering at all. + if(debug) { + p_R.print("p_R: "); + p_Fp_Sp.print("p_Fp_Sp: "); + p_Sp_R.print("p_Sp_R: "); + } - if(debug) { - p_R.print("p_R: "); - p_Fp_Sp.print("p_Fp_Sp: "); - p_Sp_R.print("p_Sp_R: "); - } + // We want to factor into a conditional of the clique variables given the + // root and the marginal on the root, integrating out all other variables. + // The integrands include any parents of this clique and the variables of + // the parent clique. + FastSet variablesAtBack; + FastSet separator; + size_t uniqueRootVariables = 0; + BOOST_FOREACH(const Index separatorIndex, this->conditional()->parents()) { + variablesAtBack.insert(separatorIndex); + separator.insert(separatorIndex); + if(debug) std::cout << "At back (this): " << separatorIndex << std::endl; + } + BOOST_FOREACH(const Index key, R->conditional()->keys()) { + if(variablesAtBack.insert(key).second) + ++ uniqueRootVariables; + if(debug) std::cout << "At back (root): " << key << std::endl; + } - // We want to factor into a conditional of the clique variables given the - // root and the marginal on the root, integrating out all other variables. - // The integrands include any parents of this clique and the variables of - // the parent clique. - FastSet variablesAtBack; - FastSet separator; - size_t uniqueRootVariables = 0; - BOOST_FOREACH(const Index separatorIndex, this->conditional()->parents()) { - variablesAtBack.insert(separatorIndex); - separator.insert(separatorIndex); - if(debug) std::cout << "At back (this): " << separatorIndex << std::endl; - } - BOOST_FOREACH(const Index key, R->conditional()->keys()) { - if(variablesAtBack.insert(key).second) - ++ uniqueRootVariables; - if(debug) std::cout << "At back (root): " << key << std::endl; - } + Permutation toBack = Permutation::PushToBack( + std::vector(variablesAtBack.begin(), variablesAtBack.end()), + R->conditional()->lastFrontalKey() + 1); + Permutation::shared_ptr toBackInverse(toBack.inverse()); + BOOST_FOREACH(const typename FactorType::shared_ptr& factor, p_Cp_R) { + factor->permuteWithInverse(*toBackInverse); } + typename BayesNet::shared_ptr eliminated(EliminationTree< + FactorType>::Create(p_Cp_R)->eliminate(function)); - Permutation toBack = Permutation::PushToBack( - std::vector(variablesAtBack.begin(), variablesAtBack.end()), - R->conditional()->lastFrontalKey() + 1); - Permutation::shared_ptr toBackInverse(toBack.inverse()); - BOOST_FOREACH(const typename FactorType::shared_ptr& factor, p_Cp_R) { - factor->permuteWithInverse(*toBackInverse); } - typename BayesNet::shared_ptr eliminated(EliminationTree< - FactorType>::Create(p_Cp_R)->eliminate(function)); + // Take only the conditionals for p(S|R). We check for each variable being + // in the separator set because if some separator variables overlap with + // root variables, we cannot rely on the number of root variables, and also + // want to include those variables in the conditional. + BOOST_REVERSE_FOREACH(typename ConditionalType::shared_ptr conditional, *eliminated) { + assert(conditional->nrFrontals() == 1); + if(separator.find(toBack[conditional->firstFrontalKey()]) != separator.end()) { + if(debug) + conditional->print("Taking C|R conditional: "); + p_S_R.push_front(conditional); + } + if(p_S_R.size() == separator.size()) + break; + } - // Take only the conditionals for p(S|R). We check for each variable being - // in the separator set because if some separator variables overlap with - // root variables, we cannot rely on the number of root variables, and also - // want to include those variables in the conditional. - BOOST_REVERSE_FOREACH(typename ConditionalType::shared_ptr conditional, *eliminated) { - assert(conditional->nrFrontals() == 1); - if(separator.find(toBack[conditional->firstFrontalKey()]) != separator.end()) { - if(debug) - conditional->print("Taking C|R conditional: "); - p_S_R.push_front(conditional); - } - if(p_S_R.size() == separator.size()) - break; - } + // Undo the permutation + if(debug) toBack.print("toBack: "); + p_S_R.permuteWithInverse(toBack); + } - // Undo the permutation - if(debug) toBack.print("toBack: "); - p_S_R.permuteWithInverse(toBack); + cachedShortcut_ = p_S_R; + } + else + p_S_R = *cachedShortcut_; // return the cached version - assertInvariants(); - cachedShortcut_ = p_S_R; - } - else - p_S_R = *cachedShortcut_; + assertInvariants(); - // return the shortcut P(S|R) - return p_S_R; + // return the shortcut P(S|R) + return p_S_R; } /* ************************************************************************* */