fix bug in GaussianBayesTree::logDeterminant
							parent
							
								
									c709932f98
								
							
						
					
					
						commit
						3dcf9d8da8
					
				|  | @ -31,18 +31,37 @@ namespace gtsam { | |||
|   template class BayesTreeCliqueBase<GaussianBayesTreeClique, GaussianFactorGraph>; | ||||
|   template class BayesTree<GaussianBayesTreeClique>; | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   namespace internal | ||||
|   { | ||||
|     /* ************************************************************************* */ | ||||
|   double logDeterminant(const GaussianBayesTreeClique::shared_ptr& clique, | ||||
|                         double& parentSum) { | ||||
|     parentSum += clique->conditional() | ||||
|                      ->R() | ||||
|                      .diagonal() | ||||
|                      .unaryExpr([](double x) { return log(x); }) | ||||
|                      .sum(); | ||||
|     return 0; | ||||
|   /* ************************************************************************ */ | ||||
|   namespace internal { | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Struct to help with traversing the Bayes Tree | ||||
|    * for log-determinant computation. | ||||
|    * Records the data which is passed to the child nodes in pre-order visit. | ||||
|    */ | ||||
|   struct LogDeterminantData { | ||||
|     // Use pointer so we can get the full result after tree traversal
 | ||||
|     double* logDet; | ||||
|     LogDeterminantData(double* logDet) | ||||
|         : logDet(logDet) {} | ||||
|   }; | ||||
|   /* ************************************************************************ */ | ||||
|   LogDeterminantData& logDeterminant( | ||||
|       const GaussianBayesTreeClique::shared_ptr& clique, | ||||
|       LogDeterminantData& parentSum) { | ||||
|     auto cg = clique->conditional(); | ||||
|     double logDet; | ||||
|     if (cg->get_model()) { | ||||
|       Vector diag = cg->R().diagonal(); | ||||
|       cg->get_model()->whitenInPlace(diag); | ||||
|       logDet = diag.unaryExpr([](double x) { return log(x); }).sum(); | ||||
|     } else { | ||||
|       logDet = | ||||
|           cg->R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); | ||||
|     } | ||||
|     // Add the current clique's log-determinant to the overall sum
 | ||||
|     (*parentSum.logDet) += logDet; | ||||
|     return parentSum; | ||||
|   } | ||||
|   }  // namespace internal
 | ||||
| 
 | ||||
|  | @ -87,7 +106,14 @@ namespace gtsam { | |||
|       return 0.0; | ||||
|     } else { | ||||
|       double sum = 0.0; | ||||
|       treeTraversal::DepthFirstForest(*this, sum, internal::logDeterminant); | ||||
|       // Store the log-determinant in this struct.
 | ||||
|       internal::LogDeterminantData rootData(&sum); | ||||
|       // No need to do anything for post-operation.
 | ||||
|       treeTraversal::no_op visitorPost; | ||||
|       // Limits OpenMP threads if we're mixing TBB and OpenMP
 | ||||
|       TbbOpenMPMixedScope threadLimiter; | ||||
|       // Traverse the GaussianBayesTree depth first and call logDeterminant on each node.
 | ||||
|       treeTraversal::DepthFirstForestParallel(*this, rootData, internal::logDeterminant, visitorPost); | ||||
|       return sum; | ||||
|     } | ||||
|   } | ||||
|  | @ -106,7 +132,3 @@ namespace gtsam { | |||
| 
 | ||||
| 
 | ||||
| } // \namespace gtsam
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue