diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 8c3e2e7b6..1eb428669 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -77,7 +77,7 @@ namespace gtsam { DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine // (size_t nrFrontals, ADT::Binary op) const { - if (nrFrontals == 0 || nrFrontals > size()) throw invalid_argument( + if (nrFrontals > size()) throw invalid_argument( (boost::format( "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") % nrFrontals % size()).str()); diff --git a/gtsam/inference/BayesTree-inl.h b/gtsam/inference/BayesTree-inl.h index 2afc88e8a..faacea8a7 100644 --- a/gtsam/inference/BayesTree-inl.h +++ b/gtsam/inference/BayesTree-inl.h @@ -20,6 +20,7 @@ #pragma once +#include #include #include #include @@ -56,12 +57,6 @@ namespace gtsam { } } - /* ************************************************************************* */ - template - size_t BayesTree::numCachedShortcuts() const { - return (root_) ? root_->numCachedShortcuts() : 0; - } - /* ************************************************************************* */ template size_t BayesTree::numCachedSeparatorMarginals() const { @@ -564,25 +559,92 @@ namespace gtsam { template typename FactorGraph::shared_ptr BayesTree::joint(Index j1, Index j2, Eliminate function) const { + gttic(BayesTree_joint); -#ifdef SHORTCUT_JOINTS // get clique C1 and C2 sharedClique C1 = (*this)[j1], C2 = (*this)[j2]; - // calculate joint - FactorGraph p_C1C2(C1->joint(C2, root_, function)); + gttic(Lowest_common_ancestor); + // Find lowest common ancestor clique + sharedClique B; { + // Build two paths to the root + FastList path1, path2; { + sharedClique p = C1; + while(p) { + path1.push_front(p); + p = p->parent(); + } + } { + sharedClique p = C2; + while(p) { + path2.push_front(p); + p = p->parent(); + } + } + // Find the path intersection + B = this->root(); + FastList::const_iterator p1 = path1.begin(), p2 = path2.begin(); + while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) { + B = *p1; + ++p1; + ++p2; + } + } + gttoc(Lowest_common_ancestor); - // eliminate remaining factor graph to get requested joint - std::vector j12(2); j12[0] = j1; j12[1] = j2; - GenericSequentialSolver solver(p_C1C2); - return solver.jointFactorGraph(j12,function); -#else - std::vector indices(2); - indices[0] = j1; - indices[1] = j2; - GenericSequentialSolver solver(FactorGraph(*this)); - return solver.jointFactorGraph(indices, function); -#endif + // Compute marginal on lowest common ancestor clique + gttic(LCA_marginal); + FactorGraph p_B = B->marginal2(this->root(), function); + gttoc(LCA_marginal); + + // Compute shortcuts of the requested cliques given the lowest common ancestor + gttic(Clique_shortcuts); + BayesNet p_C1_Bred = C1->shortcut(B, function); + BayesNet p_C2_Bred = C2->shortcut(B, function); + gttoc(Clique_shortcuts); + + // Factor the shortcuts to be conditioned on the full root + // Get the set of variables to eliminate, which is C1\B. + gttic(Full_root_factoring); + sharedConditional p_C1_B; { + std::vector C1_minus_B; { + FastSet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents()); + BOOST_FOREACH(const Index j, *B->conditional()) { + C1_minus_B_set.erase(j); } + C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end()); + } + // Factor into C1\B | B. + FactorGraph temp_remaining; + boost::tie(p_C1_B, temp_remaining) = FactorGraph(p_C1_Bred).eliminate(C1_minus_B, function); + } + sharedConditional p_C2_B; { + std::vector C2_minus_B; { + FastSet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents()); + BOOST_FOREACH(const Index j, *B->conditional()) { + C2_minus_B_set.erase(j); } + C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end()); + } + // Factor into C2\B | B. + FactorGraph temp_remaining; + boost::tie(p_C2_B, temp_remaining) = FactorGraph(p_C2_Bred).eliminate(C2_minus_B, function); + } + gttoc(Full_root_factoring); + + gttic(Variable_joint); + // Build joint on all involved variables + FactorGraph p_BC1C2; + p_BC1C2.push_back(p_B); + p_BC1C2.push_back(p_C1_B->toFactor()); + p_BC1C2.push_back(p_C2_B->toFactor()); + if(C1 != B) + p_BC1C2.push_back(C1->conditional()->toFactor()); + if(C2 != B) + p_BC1C2.push_back(C2->conditional()->toFactor()); + + // Compute final marginal by eliminating other variables + GenericSequentialSolver solver(p_BC1C2); + std::vector js; js.push_back(j1); js.push_back(j2); + return solver.jointFactorGraph(js, function); } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 26d896492..3982a5987 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -176,9 +176,6 @@ namespace gtsam { /** Gather data on all cliques */ CliqueData getCliqueData() const; - /** Collect number of cliques with cached shortcuts */ - size_t numCachedShortcuts() const; - /** Collect number of cliques with cached separator marginals */ size_t numCachedSeparatorMarginals() const; diff --git a/gtsam/inference/BayesTreeCliqueBase-inl.h b/gtsam/inference/BayesTreeCliqueBase-inl.h index 30b42b48f..c9d8e1f5c 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inl.h +++ b/gtsam/inference/BayesTreeCliqueBase-inl.h @@ -99,19 +99,6 @@ namespace gtsam { return size; } - /* ************************************************************************* */ - template - size_t BayesTreeCliqueBase::numCachedShortcuts() const { - if (!cachedShortcut_) - return 0; - - size_t subtree_count = 1; - BOOST_FOREACH(const derived_ptr& child, children_) - subtree_count += child->numCachedShortcuts(); - - return subtree_count; - } - /* ************************************************************************* */ template size_t BayesTreeCliqueBase::numCachedSeparatorMarginals() const { @@ -178,111 +165,51 @@ namespace gtsam { derived_ptr B, Eliminate function) const { gttic(BayesTreeCliqueBase_shortcut); - // Check if the ShortCut already exists - if (!cachedShortcut_) { - gttic(BayesTreeCliqueBase_shortcut_cachemiss); - // We only calculate the shortcut when this clique is not B - // and when the S\B is not empty - std::vector S_setminus_B = separator_setminus_B(B); - if (B.get() != this && !S_setminus_B.empty()) { + // We only calculate the shortcut when this clique is not B + // and when the S\B is not empty + std::vector S_setminus_B = separator_setminus_B(B); + if (B.get() != this && !S_setminus_B.empty()) { - // Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph - derived_ptr parent(parent_.lock()); - gttoc(BayesTreeCliqueBase_shortcut_cachemiss); - gttoc(BayesTreeCliqueBase_shortcut); - FactorGraph p_Cp_B(parent->shortcut(B, function)); // P(Sp||B) - gttic(BayesTreeCliqueBase_shortcut); - gttic(BayesTreeCliqueBase_shortcut_cachemiss); - p_Cp_B.push_back(parent->conditional()->toFactor()); // P(Fp|Sp) + // Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph + derived_ptr parent(parent_.lock()); + gttoc(BayesTreeCliqueBase_shortcut); + FactorGraph p_Cp_B(parent->shortcut(B, function)); // P(Sp||B) + gttic(BayesTreeCliqueBase_shortcut); + p_Cp_B.push_back(parent->conditional()->toFactor()); // P(Fp|Sp) - // Determine the variables we want to keepSet, S union B - std::vector keep = shortcut_indices(B, p_Cp_B); + // Determine the variables we want to keepSet, S union B + std::vector keep = shortcut_indices(B, p_Cp_B); - // Reduce the variable indices to start at zero - gttic(Reduce); - const Permutation reduction = internal::createReducingPermutation(p_Cp_B.keys()); - internal::Reduction inverseReduction = internal::Reduction::CreateAsInverse(reduction); - BOOST_FOREACH(const boost::shared_ptr& factor, p_Cp_B) { - if(factor) factor->reduceWithInverse(inverseReduction); } - inverseReduction.applyInverse(keep); - gttoc(Reduce); + // Reduce the variable indices to start at zero + gttic(Reduce); + const Permutation reduction = internal::createReducingPermutation(p_Cp_B.keys()); + internal::Reduction inverseReduction = internal::Reduction::CreateAsInverse(reduction); + BOOST_FOREACH(const boost::shared_ptr& factor, p_Cp_B) { + if(factor) factor->reduceWithInverse(inverseReduction); } + inverseReduction.applyInverse(keep); + gttoc(Reduce); - // Create solver that will marginalize for us - GenericSequentialSolver solver(p_Cp_B); + // Create solver that will marginalize for us + GenericSequentialSolver solver(p_Cp_B); - // Finally, we only want to have S\B variables in the Bayes net, so - size_t nrFrontals = S_setminus_B.size(); - cachedShortcut_ = // - *solver.conditionalBayesNet(keep, nrFrontals, function); + // Finally, we only want to have S\B variables in the Bayes net, so + size_t nrFrontals = S_setminus_B.size(); + BayesNet result = *solver.conditionalBayesNet(keep, nrFrontals, function); - // Undo the reduction - gttic(Undo_Reduce); - BOOST_FOREACH(const typename boost::shared_ptr& factor, p_Cp_B) { - if (factor) factor->permuteWithInverse(reduction); } - cachedShortcut_->permuteWithInverse(reduction); - gttoc(Undo_Reduce); + // Undo the reduction + gttic(Undo_Reduce); + BOOST_FOREACH(const typename boost::shared_ptr& factor, p_Cp_B) { + if (factor) factor->permuteWithInverse(reduction); } + result.permuteWithInverse(reduction); + gttoc(Undo_Reduce); - assertInvariants(); - } else { - BayesNet empty; - cachedShortcut_ = empty; - } + assertInvariants(); + + return result; } else { - gttic(BayesTreeCliqueBase_shortcut_cachehit); + return BayesNet(); } - - // return the shortcut P(S||B) - return *cachedShortcut_; // return the cached version - } - - /* ************************************************************************* */ - // P(C) = \int_R P(F|S) P(S|R) P(R) - // TODO: Maybe we should integrate given parent marginal P(Cp), - // \int(Cp\S) P(F|S)P(S|Cp)P(Cp) - // Because the root clique could be very big. - /* ************************************************************************* */ - template - FactorGraph::FactorType> BayesTreeCliqueBase< - DERIVED, CONDITIONAL>::marginal(derived_ptr R, Eliminate function) const - { - gttic(BayesTreeCliqueBase_marginal); - // If we are the root, just return this root - // NOTE: immediately cast to a factor graph - BayesNet bn(R->conditional()); - if (R.get() == this) - return bn; - - // Combine P(F|S), P(S|R), and P(R) - BayesNet p_FSRc = this->shortcut(R, function); - p_FSRc.push_front(this->conditional()); - p_FSRc.push_back(R->conditional()); - FactorGraph p_FSR = p_FSRc; - - assertInvariants(); - - // Reduce the variable indices to start at zero - gttic(Reduce); - const Permutation reduction = internal::createReducingPermutation(p_FSR.keys()); - internal::Reduction inverseReduction = internal::Reduction::CreateAsInverse(reduction); - BOOST_FOREACH(const boost::shared_ptr& factor, p_FSR) { - factor->reduceWithInverse(inverseReduction); } - std::vector keysFS = conditional_->keys(); - inverseReduction.applyInverse(keysFS); - gttoc(Reduce); - - // Eliminate to get the marginal - const GenericSequentialSolver solver(p_FSR); - FactorGraph::FactorType> result = - *solver.jointFactorGraph(keysFS, function); - - // Undo the reduction (don't need to undo p_FSR since the FactorGraph conversion no longer references the cached shortcuts) - gttic(Undo_Reduce); - BOOST_FOREACH(const typename boost::shared_ptr& factor, result) { - if (factor) factor->permuteWithInverse(reduction); } - gttoc(Undo_Reduce); - - return result; } /* ************************************************************************* */ @@ -364,53 +291,6 @@ namespace gtsam { return p_C; } -#ifdef SHORTCUT_JOINTS - /* ************************************************************************* */ - // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R) - /* ************************************************************************* */ - template - FactorGraph::FactorType> BayesTreeCliqueBase< - DERIVED, CONDITIONAL>::joint(derived_ptr C2, derived_ptr R, - Eliminate function) const - { - gttic(BayesTreeCliqueBase_joint); - // For now, assume neither is the root - - sharedConditional p_F1_S1 = this->conditional(); - sharedConditional p_F2_S2 = C2->conditional(); - - // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R) - FactorGraph joint; - if (!isRoot()) { - joint.push_back(p_F1_S1->toFactor()); // P(F1|S1) - joint.push_back(shortcut(R, function)); // P(S1|R) - } - if (!C2->isRoot()) { - joint.push_back(p_F2_S2->toFactor()); // P(F2|S2) - joint.push_back(C2->shortcut(R, function)); // P(S2|R) - } - joint.push_back(R->conditional()->toFactor()); // P(R) - - // Merge the keys of C1 and C2 - std::vector keys12; - std::vector &indices1 = p_F1_S1->keys(), &indices2 = p_F2_S2->keys(); - std::set_union(indices1.begin(), indices1.end(), // - indices2.begin(), indices2.end(), std::back_inserter(keys12)); - - // Check validity - bool cliques_intersect = (keys12.size() < indices1.size() + indices2.size()); - if (!isRoot() && !C2->isRoot() && cliques_intersect) - throw std::runtime_error( - "BayesTreeCliqueBase::joint can only calculate joint if cliques are disjoint\n" - "or one of them is the root clique"); - - // Calculate the marginal - assertInvariants(); - GenericSequentialSolver solver(joint); - return *solver.jointFactorGraph(keys12, function); - } -#endif - /* ************************************************************************* */ template void BayesTreeCliqueBase::deleteCachedShortcuts() { @@ -418,13 +298,13 @@ namespace gtsam { // When a shortcut is requested, all of the shortcuts between it and the // root are also generated. So, if this clique's cached shortcut is set, // recursively call over all child cliques. Otherwise, it is unnecessary. - if (cachedShortcut_) { + if (cachedSeparatorMarginal_) { BOOST_FOREACH(derived_ptr& child, children_) { child->deleteCachedShortcuts(); } //Delete CachedShortcut for this clique - this->resetCachedShortcut(); + cachedSeparatorMarginal_ = boost::none; } } diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index 0a47d6513..487d10a0d 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -79,9 +79,6 @@ namespace gtsam { /// @} - /// This stores the Cached Shortcut value - mutable boost::optional > cachedShortcut_; - /// This stores the Cached separator margnal P(S) mutable boost::optional > cachedSeparatorMarginal_; @@ -124,9 +121,6 @@ namespace gtsam { /** The size of subtree rooted at this clique, i.e., nr of Cliques */ size_t treeSize() const; - /** Collect number of cliques with cached shortcuts in subtree */ - size_t numCachedShortcuts() const; - /** Collect number of cliques with cached separator marginals */ size_t numCachedSeparatorMarginals() const; @@ -194,34 +188,18 @@ namespace gtsam { /** return the conditional P(S|Root) on the separator given the root */ BayesNet shortcut(derived_ptr root, Eliminate function) const; - /** return the marginal P(C) of the clique */ - FactorGraph marginal(derived_ptr root, Eliminate function) const; - /** return the marginal P(S) on the separator */ FactorGraph separatorMarginal(derived_ptr root, Eliminate function) const; /** return the marginal P(C) of the clique, using marginal caching */ FactorGraph marginal2(derived_ptr root, Eliminate function) const; -#ifdef SHORTCUT_JOINTS - /** - * return the joint P(C1,C2), where C1==this. TODO: not a method? - * Limitation: can only calculate joint if cliques are disjoint or one of them is root - */ - FactorGraph joint(derived_ptr C2, derived_ptr root, Eliminate function) const; -#endif - /** * This deletes the cached shortcuts of all cliques (subtree) below this clique. * This is performed when the bayes tree is modified. */ void deleteCachedShortcuts(); - /** return cached shortcut of the clique */ - const boost::optional >& cachedShortcut() const { - return cachedShortcut_; - } - const boost::optional >& cachedSeparatorMarginal() const { return cachedSeparatorMarginal_; } @@ -247,12 +225,6 @@ namespace gtsam { std::vector shortcut_indices(derived_ptr B, const FactorGraph& p_Cp_B) const; - /// Reset the computed shortcut of this clique. Used by friend BayesTree - void resetCachedShortcut() { - cachedSeparatorMarginal_ = boost::none; - cachedShortcut_ = boost::none; - } - private: /** diff --git a/gtsam/inference/EliminationTree-inl.h b/gtsam/inference/EliminationTree-inl.h index b63699fa5..d85e2e651 100644 --- a/gtsam/inference/EliminationTree-inl.h +++ b/gtsam/inference/EliminationTree-inl.h @@ -141,11 +141,8 @@ typename EliminationTree::shared_ptr EliminationTree::Create( // Hang factors in right places gttic(hang_factors); - BOOST_FOREACH(const typename boost::shared_ptr& derivedFactor, factorGraph) { - // Here we upwards-cast to the factor type of this EliminationTree. This - // allows performing symbolic elimination on, for example, GaussianFactors. - if(derivedFactor) { - sharedFactor factor(derivedFactor); + BOOST_FOREACH(const typename boost::shared_ptr& factor, factorGraph) { + if(factor && factor->size() > 0) { Index j = *std::min_element(factor->begin(), factor->end()); if(j < structure.size()) trees[j]->add(factor); diff --git a/gtsam/inference/SymbolicFactorGraph.cpp b/gtsam/inference/SymbolicFactorGraph.cpp index 4a3835696..05678f26f 100644 --- a/gtsam/inference/SymbolicFactorGraph.cpp +++ b/gtsam/inference/SymbolicFactorGraph.cpp @@ -120,8 +120,8 @@ namespace gtsam { BOOST_FOREACH(Index var, *factor) keys.insert(var); - if (keys.size() < 1) throw invalid_argument( - "IndexFactor::CombineAndEliminate called on factors with no variables."); + if (keys.size() < nrFrontals) throw invalid_argument( + "EliminateSymbolic requested to eliminate more variables than exist in graph."); vector newKeys(keys.begin(), keys.end()); return make_pair(boost::make_shared(newKeys, nrFrontals), diff --git a/gtsam/inference/tests/testBayesTree.cpp b/gtsam/inference/tests/testBayesTree.cpp index 3d3656aa4..bbc02e07c 100644 --- a/gtsam/inference/tests/testBayesTree.cpp +++ b/gtsam/inference/tests/testBayesTree.cpp @@ -290,10 +290,9 @@ TEST( BayesTree, shortcutCheck ) // Check if all the cached shortcuts are cleared rootClique->deleteCachedShortcuts(); BOOST_FOREACH(SymbolicBayesTree::sharedClique& clique, allCliques) { - bool notCleared = clique->cachedShortcut(); + bool notCleared = clique->cachedSeparatorMarginal(); CHECK( notCleared == false); } - EXPECT_LONGS_EQUAL(0, rootClique->numCachedShortcuts()); EXPECT_LONGS_EQUAL(0, rootClique->numCachedSeparatorMarginals()); // BOOST_FOREACH(SymbolicBayesTree::sharedClique& clique, allCliques) { diff --git a/tests/testGaussianBayesTree.cpp b/tests/testGaussianBayesTree.cpp index f1fbf7075..85ea792fe 100644 --- a/tests/testGaussianBayesTree.cpp +++ b/tests/testGaussianBayesTree.cpp @@ -321,6 +321,45 @@ TEST(GaussianBayesTree, simpleMarginal) EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST(GaussianBayesTree, shortcut_overlapping_separator) +{ + // Test computing shortcuts when the separator overlaps. This previously + // would have highlighted a problem where information was duplicated. + + // Create factor graph: + // f(1,2,5) + // f(3,4,5) + // f(5,6) + // f(6,7) + GaussianFactorGraph fg; + noiseModel::Diagonal::shared_ptr model = noiseModel::Unit::Create(1); + fg.add(1, Matrix_(1,1, 1.0), 3, Matrix_(1,1, 2.0), 5, Matrix_(1,1, 3.0), Vector_(1, 4.0), model); + fg.add(1, Matrix_(1,1, 5.0), Vector_(1, 6.0), model); + fg.add(2, Matrix_(1,1, 7.0), 4, Matrix_(1,1, 8.0), 5, Matrix_(1,1, 9.0), Vector_(1, 10.0), model); + fg.add(2, Matrix_(1,1, 11.0), Vector_(1, 12.0), model); + fg.add(5, Matrix_(1,1, 13.0), 6, Matrix_(1,1, 14.0), Vector_(1, 15.0), model); + fg.add(6, Matrix_(1,1, 17.0), 7, Matrix_(1,1, 18.0), Vector_(1, 19.0), model); + fg.add(7, Matrix_(1,1, 20.0), Vector_(1, 21.0), model); + + // Eliminate into BayesTree + // c(6,7) + // c(5|6) + // c(1,2|5) + // c(3,4|5) + GaussianBayesTree bt = *GaussianMultifrontalSolver(fg).eliminate(); + + GaussianFactorGraph joint = *bt.joint(1,2, EliminateQR); + + Matrix expectedJointJ = (Matrix(2,3) << + 0, 11, 12, + -5, 0, -6 + ).finished(); + Matrix actualJointJ = joint.augmentedJacobian(); + + EXPECT(assert_equal(expectedJointJ, actualJointJ)); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */