fix bug in GaussianBayesTree::logDeterminant

release/4.3a0
Varun Agrawal 2022-11-14 10:54:03 -05:00
parent c709932f98
commit 3dcf9d8da8
1 changed files with 39 additions and 17 deletions

View File

@ -31,18 +31,37 @@ namespace gtsam {
template class BayesTreeCliqueBase<GaussianBayesTreeClique, GaussianFactorGraph>;
template class BayesTree<GaussianBayesTreeClique>;
/* ************************************************************************* */
namespace internal
{
/* ************************************************************************* */
double logDeterminant(const GaussianBayesTreeClique::shared_ptr& clique,
double& parentSum) {
parentSum += clique->conditional()
->R()
.diagonal()
.unaryExpr([](double x) { return log(x); })
.sum();
return 0;
/* ************************************************************************ */
namespace internal {
/**
* @brief Struct to help with traversing the Bayes Tree
* for log-determinant computation.
* Records the data which is passed to the child nodes in pre-order visit.
*/
struct LogDeterminantData {
// Use pointer so we can get the full result after tree traversal
double* logDet;
LogDeterminantData(double* logDet)
: logDet(logDet) {}
};
/* ************************************************************************ */
LogDeterminantData& logDeterminant(
const GaussianBayesTreeClique::shared_ptr& clique,
LogDeterminantData& parentSum) {
auto cg = clique->conditional();
double logDet;
if (cg->get_model()) {
Vector diag = cg->R().diagonal();
cg->get_model()->whitenInPlace(diag);
logDet = diag.unaryExpr([](double x) { return log(x); }).sum();
} else {
logDet =
cg->R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
}
// Add the current clique's log-determinant to the overall sum
(*parentSum.logDet) += logDet;
return parentSum;
}
} // namespace internal
@ -87,7 +106,14 @@ namespace gtsam {
return 0.0;
} else {
double sum = 0.0;
treeTraversal::DepthFirstForest(*this, sum, internal::logDeterminant);
// Store the log-determinant in this struct.
internal::LogDeterminantData rootData(&sum);
// No need to do anything for post-operation.
treeTraversal::no_op visitorPost;
// Limits OpenMP threads if we're mixing TBB and OpenMP
TbbOpenMPMixedScope threadLimiter;
// Traverse the GaussianBayesTree depth first and call logDeterminant on each node.
treeTraversal::DepthFirstForestParallel(*this, rootData, internal::logDeterminant, visitorPost);
return sum;
}
}
@ -106,7 +132,3 @@ namespace gtsam {
} // \namespace gtsam