diff --git a/gtsam/linear/GaussianBayesTree.cpp b/gtsam/linear/GaussianBayesTree.cpp index 13c19bce6..a83475e26 100644 --- a/gtsam/linear/GaussianBayesTree.cpp +++ b/gtsam/linear/GaussianBayesTree.cpp @@ -31,18 +31,37 @@ namespace gtsam { template class BayesTreeCliqueBase; template class BayesTree; - /* ************************************************************************* */ - namespace internal - { - /* ************************************************************************* */ - double logDeterminant(const GaussianBayesTreeClique::shared_ptr& clique, - double& parentSum) { - parentSum += clique->conditional() - ->R() - .diagonal() - .unaryExpr([](double x) { return log(x); }) - .sum(); - return 0; + /* ************************************************************************ */ + namespace internal { + + /** + * @brief Struct to help with traversing the Bayes Tree + * for log-determinant computation. + * Records the data which is passed to the child nodes in pre-order visit. + */ + struct LogDeterminantData { + // Use pointer so we can get the full result after tree traversal + double* logDet; + LogDeterminantData(double* logDet) + : logDet(logDet) {} + }; + /* ************************************************************************ */ + LogDeterminantData& logDeterminant( + const GaussianBayesTreeClique::shared_ptr& clique, + LogDeterminantData& parentSum) { + auto cg = clique->conditional(); + double logDet; + if (cg->get_model()) { + Vector diag = cg->R().diagonal(); + cg->get_model()->whitenInPlace(diag); + logDet = diag.unaryExpr([](double x) { return log(x); }).sum(); + } else { + logDet = + cg->R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); + } + // Add the current clique's log-determinant to the overall sum + (*parentSum.logDet) += logDet; + return parentSum; } } // namespace internal @@ -87,7 +106,14 @@ namespace gtsam { return 0.0; } else { double sum = 0.0; - treeTraversal::DepthFirstForest(*this, sum, internal::logDeterminant); + // Store the log-determinant in this struct. + internal::LogDeterminantData rootData(&sum); + // No need to do anything for post-operation. + treeTraversal::no_op visitorPost; + // Limits OpenMP threads if we're mixing TBB and OpenMP + TbbOpenMPMixedScope threadLimiter; + // Traverse the GaussianBayesTree depth first and call logDeterminant on each node. + treeTraversal::DepthFirstForestParallel(*this, rootData, internal::logDeterminant, visitorPost); return sum; } } @@ -106,7 +132,3 @@ namespace gtsam { } // \namespace gtsam - - - -