Joint marginals using lowest-common-ancestor shortcuts. As part of this commit, caching of shortcuts is removed, the BayesTreeCliqueBase::marginal function computing single-variable shortcut marginals is removed, and the factor/frontal size checks in symbolic and discrete elimination are modified to permit eliminating empty factors or zero frontal variables.
							parent
							
								
									279738c56f
								
							
						
					
					
						commit
						4d4e17c2a7
					
				| 
						 | 
				
			
			@ -77,7 +77,7 @@ namespace gtsam {
 | 
			
		|||
  DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine //
 | 
			
		||||
  (size_t nrFrontals, ADT::Binary op) const {
 | 
			
		||||
 | 
			
		||||
    if (nrFrontals == 0 || nrFrontals > size()) throw invalid_argument(
 | 
			
		||||
    if (nrFrontals > size()) throw invalid_argument(
 | 
			
		||||
        (boost::format(
 | 
			
		||||
            "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
 | 
			
		||||
            % nrFrontals % size()).str());
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,6 +20,7 @@
 | 
			
		|||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <gtsam/base/FastList.h>
 | 
			
		||||
#include <gtsam/base/FastSet.h>
 | 
			
		||||
#include <gtsam/base/FastVector.h>
 | 
			
		||||
#include <gtsam/inference/BayesTree.h>
 | 
			
		||||
| 
						 | 
				
			
			@ -56,12 +57,6 @@ namespace gtsam {
 | 
			
		|||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  template<class CONDITIONAL, class CLIQUE>
 | 
			
		||||
  size_t BayesTree<CONDITIONAL,CLIQUE>::numCachedShortcuts() const {
 | 
			
		||||
    return (root_) ? root_->numCachedShortcuts() : 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  template<class CONDITIONAL, class CLIQUE>
 | 
			
		||||
  size_t BayesTree<CONDITIONAL,CLIQUE>::numCachedSeparatorMarginals() const {
 | 
			
		||||
| 
						 | 
				
			
			@ -564,25 +559,92 @@ namespace gtsam {
 | 
			
		|||
  template<class CONDITIONAL, class CLIQUE>
 | 
			
		||||
  typename FactorGraph<typename CONDITIONAL::FactorType>::shared_ptr
 | 
			
		||||
  BayesTree<CONDITIONAL,CLIQUE>::joint(Index j1, Index j2, Eliminate function) const {
 | 
			
		||||
    gttic(BayesTree_joint);
 | 
			
		||||
 | 
			
		||||
#ifdef SHORTCUT_JOINTS
 | 
			
		||||
    // get clique C1 and C2
 | 
			
		||||
    sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
 | 
			
		||||
 | 
			
		||||
    // calculate joint
 | 
			
		||||
    FactorGraph<FactorType> p_C1C2(C1->joint(C2, root_, function));
 | 
			
		||||
    gttic(Lowest_common_ancestor);
 | 
			
		||||
    // Find lowest common ancestor clique
 | 
			
		||||
    sharedClique B; {
 | 
			
		||||
      // Build two paths to the root
 | 
			
		||||
      FastList<sharedClique> path1, path2; {
 | 
			
		||||
        sharedClique p = C1;
 | 
			
		||||
        while(p) {
 | 
			
		||||
          path1.push_front(p);
 | 
			
		||||
          p = p->parent();
 | 
			
		||||
        }
 | 
			
		||||
      } {
 | 
			
		||||
        sharedClique p = C2;
 | 
			
		||||
        while(p) {
 | 
			
		||||
          path2.push_front(p);
 | 
			
		||||
          p = p->parent();
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      // Find the path intersection
 | 
			
		||||
      B = this->root();
 | 
			
		||||
      FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
 | 
			
		||||
      while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
 | 
			
		||||
        B = *p1;
 | 
			
		||||
        ++p1;
 | 
			
		||||
        ++p2;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    gttoc(Lowest_common_ancestor);
 | 
			
		||||
 | 
			
		||||
    // eliminate remaining factor graph to get requested joint
 | 
			
		||||
    std::vector<Index> j12(2); j12[0] = j1; j12[1] = j2;
 | 
			
		||||
    GenericSequentialSolver<FactorType> solver(p_C1C2);
 | 
			
		||||
    return solver.jointFactorGraph(j12,function);
 | 
			
		||||
#else
 | 
			
		||||
    std::vector<Index> indices(2);
 | 
			
		||||
    indices[0] = j1;
 | 
			
		||||
    indices[1] = j2;
 | 
			
		||||
    GenericSequentialSolver<FactorType> solver(FactorGraph<FactorType>(*this));
 | 
			
		||||
    return solver.jointFactorGraph(indices, function);
 | 
			
		||||
#endif
 | 
			
		||||
    // Compute marginal on lowest common ancestor clique
 | 
			
		||||
    gttic(LCA_marginal);
 | 
			
		||||
    FactorGraph<FactorType> p_B = B->marginal2(this->root(), function);
 | 
			
		||||
    gttoc(LCA_marginal);
 | 
			
		||||
 | 
			
		||||
    // Compute shortcuts of the requested cliques given the lowest common ancestor
 | 
			
		||||
    gttic(Clique_shortcuts);
 | 
			
		||||
    BayesNet<CONDITIONAL> p_C1_Bred = C1->shortcut(B, function);
 | 
			
		||||
    BayesNet<CONDITIONAL> p_C2_Bred = C2->shortcut(B, function);
 | 
			
		||||
    gttoc(Clique_shortcuts);
 | 
			
		||||
 | 
			
		||||
    // Factor the shortcuts to be conditioned on the full root
 | 
			
		||||
    // Get the set of variables to eliminate, which is C1\B.
 | 
			
		||||
    gttic(Full_root_factoring);
 | 
			
		||||
    sharedConditional p_C1_B; {
 | 
			
		||||
      std::vector<Index> C1_minus_B; {
 | 
			
		||||
        FastSet<Index> C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
 | 
			
		||||
        BOOST_FOREACH(const Index j, *B->conditional()) {
 | 
			
		||||
          C1_minus_B_set.erase(j); }
 | 
			
		||||
        C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
 | 
			
		||||
      }
 | 
			
		||||
      // Factor into C1\B | B.
 | 
			
		||||
      FactorGraph<FactorType> temp_remaining;
 | 
			
		||||
      boost::tie(p_C1_B, temp_remaining) = FactorGraph<FactorType>(p_C1_Bred).eliminate(C1_minus_B, function);
 | 
			
		||||
    }
 | 
			
		||||
    sharedConditional p_C2_B; {
 | 
			
		||||
      std::vector<Index> C2_minus_B; {
 | 
			
		||||
        FastSet<Index> C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
 | 
			
		||||
        BOOST_FOREACH(const Index j, *B->conditional()) {
 | 
			
		||||
          C2_minus_B_set.erase(j); }
 | 
			
		||||
        C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
 | 
			
		||||
      }
 | 
			
		||||
      // Factor into C2\B | B.
 | 
			
		||||
      FactorGraph<FactorType> temp_remaining;
 | 
			
		||||
      boost::tie(p_C2_B, temp_remaining) = FactorGraph<FactorType>(p_C2_Bred).eliminate(C2_minus_B, function);
 | 
			
		||||
    }
 | 
			
		||||
    gttoc(Full_root_factoring);
 | 
			
		||||
 | 
			
		||||
    gttic(Variable_joint);
 | 
			
		||||
    // Build joint on all involved variables
 | 
			
		||||
    FactorGraph<FactorType> p_BC1C2;
 | 
			
		||||
    p_BC1C2.push_back(p_B);
 | 
			
		||||
    p_BC1C2.push_back(p_C1_B->toFactor());
 | 
			
		||||
    p_BC1C2.push_back(p_C2_B->toFactor());
 | 
			
		||||
    if(C1 != B)
 | 
			
		||||
      p_BC1C2.push_back(C1->conditional()->toFactor());
 | 
			
		||||
    if(C2 != B)
 | 
			
		||||
      p_BC1C2.push_back(C2->conditional()->toFactor());
 | 
			
		||||
 | 
			
		||||
    // Compute final marginal by eliminating other variables
 | 
			
		||||
    GenericSequentialSolver<FactorType> solver(p_BC1C2);
 | 
			
		||||
    std::vector<Index> js; js.push_back(j1); js.push_back(j2);
 | 
			
		||||
    return solver.jointFactorGraph(js, function);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -176,9 +176,6 @@ namespace gtsam {
 | 
			
		|||
    /** Gather data on all cliques */
 | 
			
		||||
    CliqueData getCliqueData() const;
 | 
			
		||||
 | 
			
		||||
    /** Collect number of cliques with cached shortcuts */
 | 
			
		||||
    size_t numCachedShortcuts() const;
 | 
			
		||||
 | 
			
		||||
    /** Collect number of cliques with cached separator marginals */
 | 
			
		||||
    size_t numCachedSeparatorMarginals() const;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -99,19 +99,6 @@ namespace gtsam {
 | 
			
		|||
    return size;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  template<class DERIVED, class CONDITIONAL>
 | 
			
		||||
  size_t BayesTreeCliqueBase<DERIVED, CONDITIONAL>::numCachedShortcuts() const {
 | 
			
		||||
    if (!cachedShortcut_)
 | 
			
		||||
      return 0;
 | 
			
		||||
 | 
			
		||||
    size_t subtree_count = 1;
 | 
			
		||||
    BOOST_FOREACH(const derived_ptr& child, children_)
 | 
			
		||||
      subtree_count += child->numCachedShortcuts();
 | 
			
		||||
 | 
			
		||||
    return subtree_count;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  template<class DERIVED, class CONDITIONAL>
 | 
			
		||||
  size_t BayesTreeCliqueBase<DERIVED, CONDITIONAL>::numCachedSeparatorMarginals() const {
 | 
			
		||||
| 
						 | 
				
			
			@ -178,111 +165,51 @@ namespace gtsam {
 | 
			
		|||
      derived_ptr B, Eliminate function) const
 | 
			
		||||
  {
 | 
			
		||||
    gttic(BayesTreeCliqueBase_shortcut);
 | 
			
		||||
    // Check if the ShortCut already exists
 | 
			
		||||
    if (!cachedShortcut_) {
 | 
			
		||||
 | 
			
		||||
      gttic(BayesTreeCliqueBase_shortcut_cachemiss);
 | 
			
		||||
      // We only calculate the shortcut when this clique is not B
 | 
			
		||||
      // and when the S\B is not empty
 | 
			
		||||
      std::vector<Index> S_setminus_B = separator_setminus_B(B);
 | 
			
		||||
      if (B.get() != this && !S_setminus_B.empty()) {
 | 
			
		||||
    // We only calculate the shortcut when this clique is not B
 | 
			
		||||
    // and when the S\B is not empty
 | 
			
		||||
    std::vector<Index> S_setminus_B = separator_setminus_B(B);
 | 
			
		||||
    if (B.get() != this && !S_setminus_B.empty()) {
 | 
			
		||||
 | 
			
		||||
        // Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
 | 
			
		||||
        derived_ptr parent(parent_.lock());
 | 
			
		||||
        gttoc(BayesTreeCliqueBase_shortcut_cachemiss);
 | 
			
		||||
        gttoc(BayesTreeCliqueBase_shortcut);
 | 
			
		||||
        FactorGraph<FactorType> p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
 | 
			
		||||
        gttic(BayesTreeCliqueBase_shortcut);
 | 
			
		||||
        gttic(BayesTreeCliqueBase_shortcut_cachemiss);
 | 
			
		||||
        p_Cp_B.push_back(parent->conditional()->toFactor()); // P(Fp|Sp)
 | 
			
		||||
      // Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph
 | 
			
		||||
      derived_ptr parent(parent_.lock());
 | 
			
		||||
      gttoc(BayesTreeCliqueBase_shortcut);
 | 
			
		||||
      FactorGraph<FactorType> p_Cp_B(parent->shortcut(B, function)); // P(Sp||B)
 | 
			
		||||
      gttic(BayesTreeCliqueBase_shortcut);
 | 
			
		||||
      p_Cp_B.push_back(parent->conditional()->toFactor()); // P(Fp|Sp)
 | 
			
		||||
 | 
			
		||||
        // Determine the variables we want to keepSet, S union B
 | 
			
		||||
        std::vector<Index> keep = shortcut_indices(B, p_Cp_B);
 | 
			
		||||
      // Determine the variables we want to keepSet, S union B
 | 
			
		||||
      std::vector<Index> keep = shortcut_indices(B, p_Cp_B);
 | 
			
		||||
 | 
			
		||||
        // Reduce the variable indices to start at zero
 | 
			
		||||
        gttic(Reduce);
 | 
			
		||||
        const Permutation reduction = internal::createReducingPermutation(p_Cp_B.keys());
 | 
			
		||||
        internal::Reduction inverseReduction = internal::Reduction::CreateAsInverse(reduction);
 | 
			
		||||
        BOOST_FOREACH(const boost::shared_ptr<FactorType>& factor, p_Cp_B) {
 | 
			
		||||
          if(factor) factor->reduceWithInverse(inverseReduction); }
 | 
			
		||||
        inverseReduction.applyInverse(keep);
 | 
			
		||||
        gttoc(Reduce);
 | 
			
		||||
      // Reduce the variable indices to start at zero
 | 
			
		||||
      gttic(Reduce);
 | 
			
		||||
      const Permutation reduction = internal::createReducingPermutation(p_Cp_B.keys());
 | 
			
		||||
      internal::Reduction inverseReduction = internal::Reduction::CreateAsInverse(reduction);
 | 
			
		||||
      BOOST_FOREACH(const boost::shared_ptr<FactorType>& factor, p_Cp_B) {
 | 
			
		||||
        if(factor) factor->reduceWithInverse(inverseReduction); }
 | 
			
		||||
      inverseReduction.applyInverse(keep);
 | 
			
		||||
      gttoc(Reduce);
 | 
			
		||||
 | 
			
		||||
        // Create solver that will marginalize for us
 | 
			
		||||
        GenericSequentialSolver<FactorType> solver(p_Cp_B);
 | 
			
		||||
      // Create solver that will marginalize for us
 | 
			
		||||
      GenericSequentialSolver<FactorType> solver(p_Cp_B);
 | 
			
		||||
 | 
			
		||||
        // Finally, we only want to have S\B variables in the Bayes net, so
 | 
			
		||||
        size_t nrFrontals = S_setminus_B.size();
 | 
			
		||||
        cachedShortcut_ = //
 | 
			
		||||
            *solver.conditionalBayesNet(keep, nrFrontals, function);
 | 
			
		||||
      // Finally, we only want to have S\B variables in the Bayes net, so
 | 
			
		||||
      size_t nrFrontals = S_setminus_B.size();
 | 
			
		||||
      BayesNet<CONDITIONAL> result = *solver.conditionalBayesNet(keep, nrFrontals, function);
 | 
			
		||||
 | 
			
		||||
        // Undo the reduction
 | 
			
		||||
        gttic(Undo_Reduce);
 | 
			
		||||
        BOOST_FOREACH(const typename boost::shared_ptr<FactorType>& factor, p_Cp_B) {
 | 
			
		||||
          if (factor) factor->permuteWithInverse(reduction); }
 | 
			
		||||
        cachedShortcut_->permuteWithInverse(reduction);
 | 
			
		||||
        gttoc(Undo_Reduce);
 | 
			
		||||
      // Undo the reduction
 | 
			
		||||
      gttic(Undo_Reduce);
 | 
			
		||||
      BOOST_FOREACH(const typename boost::shared_ptr<FactorType>& factor, p_Cp_B) {
 | 
			
		||||
        if (factor) factor->permuteWithInverse(reduction); }
 | 
			
		||||
      result.permuteWithInverse(reduction);
 | 
			
		||||
      gttoc(Undo_Reduce);
 | 
			
		||||
 | 
			
		||||
        assertInvariants();
 | 
			
		||||
      } else {
 | 
			
		||||
        BayesNet<CONDITIONAL> empty;
 | 
			
		||||
        cachedShortcut_ = empty;
 | 
			
		||||
      }
 | 
			
		||||
      assertInvariants();
 | 
			
		||||
 | 
			
		||||
      return result;
 | 
			
		||||
    } else {
 | 
			
		||||
      gttic(BayesTreeCliqueBase_shortcut_cachehit);
 | 
			
		||||
      return BayesNet<CONDITIONAL>();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // return the shortcut P(S||B)
 | 
			
		||||
    return *cachedShortcut_; // return the cached version
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  // P(C) = \int_R P(F|S) P(S|R) P(R)
 | 
			
		||||
  // TODO: Maybe we should integrate given parent marginal P(Cp),
 | 
			
		||||
  // \int(Cp\S) P(F|S)P(S|Cp)P(Cp)
 | 
			
		||||
  // Because the root clique could be very big.
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  template<class DERIVED, class CONDITIONAL>
 | 
			
		||||
  FactorGraph<typename BayesTreeCliqueBase<DERIVED, CONDITIONAL>::FactorType> BayesTreeCliqueBase<
 | 
			
		||||
      DERIVED, CONDITIONAL>::marginal(derived_ptr R, Eliminate function) const
 | 
			
		||||
  {
 | 
			
		||||
    gttic(BayesTreeCliqueBase_marginal);
 | 
			
		||||
    // If we are the root, just return this root
 | 
			
		||||
    // NOTE: immediately cast to a factor graph
 | 
			
		||||
    BayesNet<ConditionalType> bn(R->conditional());
 | 
			
		||||
    if (R.get() == this)
 | 
			
		||||
      return bn;
 | 
			
		||||
 | 
			
		||||
    // Combine P(F|S), P(S|R), and P(R)
 | 
			
		||||
    BayesNet<ConditionalType> p_FSRc = this->shortcut(R, function);
 | 
			
		||||
    p_FSRc.push_front(this->conditional());
 | 
			
		||||
    p_FSRc.push_back(R->conditional());
 | 
			
		||||
    FactorGraph<FactorType> p_FSR = p_FSRc;
 | 
			
		||||
 | 
			
		||||
    assertInvariants();
 | 
			
		||||
 | 
			
		||||
    // Reduce the variable indices to start at zero
 | 
			
		||||
    gttic(Reduce);
 | 
			
		||||
    const Permutation reduction = internal::createReducingPermutation(p_FSR.keys());
 | 
			
		||||
    internal::Reduction inverseReduction = internal::Reduction::CreateAsInverse(reduction);
 | 
			
		||||
    BOOST_FOREACH(const boost::shared_ptr<FactorType>& factor, p_FSR) {
 | 
			
		||||
      factor->reduceWithInverse(inverseReduction); }
 | 
			
		||||
    std::vector<Index> keysFS = conditional_->keys();
 | 
			
		||||
    inverseReduction.applyInverse(keysFS);
 | 
			
		||||
    gttoc(Reduce);
 | 
			
		||||
 | 
			
		||||
    // Eliminate to get the marginal
 | 
			
		||||
    const GenericSequentialSolver<FactorType> solver(p_FSR);
 | 
			
		||||
    FactorGraph<typename BayesTreeCliqueBase<DERIVED, CONDITIONAL>::FactorType> result =
 | 
			
		||||
      *solver.jointFactorGraph(keysFS, function);
 | 
			
		||||
 | 
			
		||||
    // Undo the reduction (don't need to undo p_FSR since the FactorGraph conversion no longer references the cached shortcuts)
 | 
			
		||||
    gttic(Undo_Reduce);
 | 
			
		||||
    BOOST_FOREACH(const typename boost::shared_ptr<FactorType>& factor, result) {
 | 
			
		||||
      if (factor) factor->permuteWithInverse(reduction); }
 | 
			
		||||
    gttoc(Undo_Reduce);
 | 
			
		||||
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
| 
						 | 
				
			
			@ -364,53 +291,6 @@ namespace gtsam {
 | 
			
		|||
    return p_C;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#ifdef SHORTCUT_JOINTS
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R)
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  template<class DERIVED, class CONDITIONAL>
 | 
			
		||||
  FactorGraph<typename BayesTreeCliqueBase<DERIVED, CONDITIONAL>::FactorType> BayesTreeCliqueBase<
 | 
			
		||||
      DERIVED, CONDITIONAL>::joint(derived_ptr C2, derived_ptr R,
 | 
			
		||||
      Eliminate function) const
 | 
			
		||||
  {
 | 
			
		||||
    gttic(BayesTreeCliqueBase_joint);
 | 
			
		||||
    // For now, assume neither is the root
 | 
			
		||||
 | 
			
		||||
    sharedConditional p_F1_S1 = this->conditional();
 | 
			
		||||
    sharedConditional p_F2_S2 = C2->conditional();
 | 
			
		||||
 | 
			
		||||
    // Combine P(F1|S1), P(S1|R), P(F2|S2), P(S2|R), and P(R)
 | 
			
		||||
    FactorGraph<FactorType> joint;
 | 
			
		||||
    if (!isRoot()) {
 | 
			
		||||
      joint.push_back(p_F1_S1->toFactor()); // P(F1|S1)
 | 
			
		||||
      joint.push_back(shortcut(R, function)); // P(S1|R)
 | 
			
		||||
    }
 | 
			
		||||
    if (!C2->isRoot()) {
 | 
			
		||||
      joint.push_back(p_F2_S2->toFactor()); // P(F2|S2)
 | 
			
		||||
      joint.push_back(C2->shortcut(R, function)); // P(S2|R)
 | 
			
		||||
    }
 | 
			
		||||
    joint.push_back(R->conditional()->toFactor()); // P(R)
 | 
			
		||||
 | 
			
		||||
    // Merge the keys of C1 and C2
 | 
			
		||||
    std::vector<Index> keys12;
 | 
			
		||||
    std::vector<Index> &indices1 = p_F1_S1->keys(), &indices2 = p_F2_S2->keys();
 | 
			
		||||
    std::set_union(indices1.begin(), indices1.end(), //
 | 
			
		||||
        indices2.begin(), indices2.end(), std::back_inserter(keys12));
 | 
			
		||||
 | 
			
		||||
    // Check validity
 | 
			
		||||
    bool cliques_intersect = (keys12.size() < indices1.size() + indices2.size());
 | 
			
		||||
    if (!isRoot() && !C2->isRoot() && cliques_intersect)
 | 
			
		||||
      throw std::runtime_error(
 | 
			
		||||
          "BayesTreeCliqueBase::joint can only calculate joint if cliques are disjoint\n"
 | 
			
		||||
          "or one of them is the root clique");
 | 
			
		||||
 | 
			
		||||
    // Calculate the marginal
 | 
			
		||||
    assertInvariants();
 | 
			
		||||
    GenericSequentialSolver<FactorType> solver(joint);
 | 
			
		||||
    return *solver.jointFactorGraph(keys12, function);
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************* */
 | 
			
		||||
  template<class DERIVED, class CONDITIONAL>
 | 
			
		||||
  void BayesTreeCliqueBase<DERIVED, CONDITIONAL>::deleteCachedShortcuts() {
 | 
			
		||||
| 
						 | 
				
			
			@ -418,13 +298,13 @@ namespace gtsam {
 | 
			
		|||
    // When a shortcut is requested, all of the shortcuts between it and the
 | 
			
		||||
    // root are also generated. So, if this clique's cached shortcut is set,
 | 
			
		||||
    // recursively call over all child cliques. Otherwise, it is unnecessary.
 | 
			
		||||
    if (cachedShortcut_) {
 | 
			
		||||
    if (cachedSeparatorMarginal_) {
 | 
			
		||||
      BOOST_FOREACH(derived_ptr& child, children_) {
 | 
			
		||||
        child->deleteCachedShortcuts();
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      //Delete CachedShortcut for this clique
 | 
			
		||||
      this->resetCachedShortcut();
 | 
			
		||||
      cachedSeparatorMarginal_ = boost::none;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -79,9 +79,6 @@ namespace gtsam {
 | 
			
		|||
 | 
			
		||||
    /// @}
 | 
			
		||||
 | 
			
		||||
    /// This stores the Cached Shortcut value
 | 
			
		||||
    mutable boost::optional<BayesNet<ConditionalType> > cachedShortcut_;
 | 
			
		||||
 | 
			
		||||
    /// This stores the Cached separator margnal P(S)
 | 
			
		||||
    mutable boost::optional<FactorGraph<FactorType> > cachedSeparatorMarginal_;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -124,9 +121,6 @@ namespace gtsam {
 | 
			
		|||
    /** The size of subtree rooted at this clique, i.e., nr of Cliques */
 | 
			
		||||
    size_t treeSize() const;
 | 
			
		||||
 | 
			
		||||
    /** Collect number of cliques with cached shortcuts in subtree */
 | 
			
		||||
    size_t numCachedShortcuts() const;
 | 
			
		||||
 | 
			
		||||
    /** Collect number of cliques with cached separator marginals */
 | 
			
		||||
    size_t numCachedSeparatorMarginals() const;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -194,34 +188,18 @@ namespace gtsam {
 | 
			
		|||
    /** return the conditional P(S|Root) on the separator given the root */
 | 
			
		||||
    BayesNet<ConditionalType> shortcut(derived_ptr root, Eliminate function) const;
 | 
			
		||||
 | 
			
		||||
    /** return the marginal P(C) of the clique */
 | 
			
		||||
    FactorGraph<FactorType> marginal(derived_ptr root, Eliminate function) const;
 | 
			
		||||
 | 
			
		||||
    /** return the marginal P(S) on the separator */
 | 
			
		||||
    FactorGraph<FactorType> separatorMarginal(derived_ptr root, Eliminate function) const;
 | 
			
		||||
 | 
			
		||||
    /** return the marginal P(C) of the clique, using marginal caching */
 | 
			
		||||
    FactorGraph<FactorType> marginal2(derived_ptr root, Eliminate function) const;
 | 
			
		||||
 | 
			
		||||
#ifdef SHORTCUT_JOINTS
 | 
			
		||||
    /**
 | 
			
		||||
     * return the joint P(C1,C2), where C1==this. TODO: not a method?
 | 
			
		||||
     * Limitation: can only calculate joint if cliques are disjoint or one of them is root
 | 
			
		||||
     */
 | 
			
		||||
    FactorGraph<FactorType> joint(derived_ptr C2, derived_ptr root, Eliminate function) const;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * This deletes the cached shortcuts of all cliques (subtree) below this clique.
 | 
			
		||||
     * This is performed when the bayes tree is modified.
 | 
			
		||||
     */
 | 
			
		||||
    void deleteCachedShortcuts();
 | 
			
		||||
 | 
			
		||||
    /** return cached shortcut of the clique */
 | 
			
		||||
    const boost::optional<BayesNet<ConditionalType> >& cachedShortcut() const {
 | 
			
		||||
      return cachedShortcut_;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const boost::optional<FactorGraph<FactorType> >& cachedSeparatorMarginal() const {
 | 
			
		||||
      return cachedSeparatorMarginal_;
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			@ -247,12 +225,6 @@ namespace gtsam {
 | 
			
		|||
    std::vector<Index> shortcut_indices(derived_ptr B,
 | 
			
		||||
        const FactorGraph<FactorType>& p_Cp_B) const;
 | 
			
		||||
 | 
			
		||||
    /// Reset the computed shortcut of this clique. Used by friend BayesTree
 | 
			
		||||
    void resetCachedShortcut() {
 | 
			
		||||
      cachedSeparatorMarginal_ = boost::none;
 | 
			
		||||
      cachedShortcut_ = boost::none;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
  private:
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -141,11 +141,8 @@ typename EliminationTree<FACTOR>::shared_ptr EliminationTree<FACTOR>::Create(
 | 
			
		|||
 | 
			
		||||
  // Hang factors in right places
 | 
			
		||||
  gttic(hang_factors);
 | 
			
		||||
  BOOST_FOREACH(const typename boost::shared_ptr<DERIVEDFACTOR>& derivedFactor, factorGraph) {
 | 
			
		||||
    // Here we upwards-cast to the factor type of this EliminationTree.  This
 | 
			
		||||
    // allows performing symbolic elimination on, for example, GaussianFactors.
 | 
			
		||||
    if(derivedFactor) {
 | 
			
		||||
      sharedFactor factor(derivedFactor);
 | 
			
		||||
  BOOST_FOREACH(const typename boost::shared_ptr<DERIVEDFACTOR>& factor, factorGraph) {
 | 
			
		||||
    if(factor && factor->size() > 0) {
 | 
			
		||||
      Index j = *std::min_element(factor->begin(), factor->end());
 | 
			
		||||
      if(j < structure.size())
 | 
			
		||||
        trees[j]->add(factor);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -120,8 +120,8 @@ namespace gtsam {
 | 
			
		|||
            BOOST_FOREACH(Index var, *factor)
 | 
			
		||||
                    keys.insert(var);
 | 
			
		||||
 | 
			
		||||
    if (keys.size() < 1) throw invalid_argument(
 | 
			
		||||
        "IndexFactor::CombineAndEliminate called on factors with no variables.");
 | 
			
		||||
    if (keys.size() < nrFrontals) throw invalid_argument(
 | 
			
		||||
        "EliminateSymbolic requested to eliminate more variables than exist in graph.");
 | 
			
		||||
 | 
			
		||||
    vector<Index> newKeys(keys.begin(), keys.end());
 | 
			
		||||
    return make_pair(boost::make_shared<IndexConditional>(newKeys, nrFrontals),
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -290,10 +290,9 @@ TEST( BayesTree, shortcutCheck )
 | 
			
		|||
  // Check if all the cached shortcuts are cleared
 | 
			
		||||
  rootClique->deleteCachedShortcuts();
 | 
			
		||||
  BOOST_FOREACH(SymbolicBayesTree::sharedClique& clique, allCliques) {
 | 
			
		||||
    bool notCleared = clique->cachedShortcut();
 | 
			
		||||
    bool notCleared = clique->cachedSeparatorMarginal();
 | 
			
		||||
    CHECK( notCleared == false);
 | 
			
		||||
  }
 | 
			
		||||
  EXPECT_LONGS_EQUAL(0, rootClique->numCachedShortcuts());
 | 
			
		||||
  EXPECT_LONGS_EQUAL(0, rootClique->numCachedSeparatorMarginals());
 | 
			
		||||
 | 
			
		||||
//  BOOST_FOREACH(SymbolicBayesTree::sharedClique& clique, allCliques) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -321,6 +321,45 @@ TEST(GaussianBayesTree, simpleMarginal)
 | 
			
		|||
  EXPECT(assert_equal(expected, actual));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(GaussianBayesTree, shortcut_overlapping_separator)
 | 
			
		||||
{
 | 
			
		||||
  // Test computing shortcuts when the separator overlaps.  This previously
 | 
			
		||||
  // would have highlighted a problem where information was duplicated.
 | 
			
		||||
 | 
			
		||||
  // Create factor graph:
 | 
			
		||||
  // f(1,2,5)
 | 
			
		||||
  // f(3,4,5)
 | 
			
		||||
  // f(5,6)
 | 
			
		||||
  // f(6,7)
 | 
			
		||||
  GaussianFactorGraph fg;
 | 
			
		||||
  noiseModel::Diagonal::shared_ptr model = noiseModel::Unit::Create(1);
 | 
			
		||||
  fg.add(1, Matrix_(1,1, 1.0), 3, Matrix_(1,1, 2.0), 5, Matrix_(1,1, 3.0), Vector_(1, 4.0), model);
 | 
			
		||||
  fg.add(1, Matrix_(1,1, 5.0), Vector_(1, 6.0), model);
 | 
			
		||||
  fg.add(2, Matrix_(1,1, 7.0), 4, Matrix_(1,1, 8.0), 5, Matrix_(1,1, 9.0), Vector_(1, 10.0), model);
 | 
			
		||||
  fg.add(2, Matrix_(1,1, 11.0), Vector_(1, 12.0), model);
 | 
			
		||||
  fg.add(5, Matrix_(1,1, 13.0), 6, Matrix_(1,1, 14.0), Vector_(1, 15.0), model);
 | 
			
		||||
  fg.add(6, Matrix_(1,1, 17.0), 7, Matrix_(1,1, 18.0), Vector_(1, 19.0), model);
 | 
			
		||||
  fg.add(7, Matrix_(1,1, 20.0), Vector_(1, 21.0), model);
 | 
			
		||||
 | 
			
		||||
  // Eliminate into BayesTree
 | 
			
		||||
  // c(6,7)
 | 
			
		||||
  // c(5|6)
 | 
			
		||||
  //   c(1,2|5)
 | 
			
		||||
  //   c(3,4|5)
 | 
			
		||||
  GaussianBayesTree bt = *GaussianMultifrontalSolver(fg).eliminate();
 | 
			
		||||
 | 
			
		||||
  GaussianFactorGraph joint = *bt.joint(1,2, EliminateQR);
 | 
			
		||||
 | 
			
		||||
  Matrix expectedJointJ = (Matrix(2,3) <<
 | 
			
		||||
    0, 11, 12,
 | 
			
		||||
    -5, 0, -6
 | 
			
		||||
    ).finished();
 | 
			
		||||
  Matrix actualJointJ = joint.augmentedJacobian();
 | 
			
		||||
 | 
			
		||||
  EXPECT(assert_equal(expectedJointJ, actualJointJ));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue