Merge pull request #1345 from borglab/varun/conditional-log-det

release/4.3a0
Varun Agrawal 2022-12-23 09:30:45 -05:00 committed by GitHub
commit 1ab922b253
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 18 deletions

View File

@ -217,14 +217,7 @@ namespace gtsam {
double GaussianBayesNet::logDeterminant() const {
double logDet = 0.0;
for (const sharedConditional& cg : *this) {
if (cg->get_model()) {
Vector diag = cg->R().diagonal();
cg->get_model()->whitenInPlace(diag);
logDet += diag.unaryExpr([](double x) { return log(x); }).sum();
} else {
logDet +=
cg->R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
}
logDet += cg->logDeterminant();
}
return logDet;
}

View File

@ -43,7 +43,7 @@ double logDeterminant(const typename BAYESTREE::sharedClique& clique) {
double result = 0.0;
// this clique
result += clique->conditional()->R().diagonal().unaryExpr(std::ptr_fun<double,double>(log)).sum();
result += clique->conditional()->logDeterminant();
// sum of children
for(const typename BAYESTREE::sharedClique& child: clique->children_)

View File

@ -50,15 +50,7 @@ namespace gtsam {
const GaussianBayesTreeClique::shared_ptr& clique,
LogDeterminantData& parentSum) {
auto cg = clique->conditional();
double logDet;
if (cg->get_model()) {
Vector diag = cg->R().diagonal();
cg->get_model()->whitenInPlace(diag);
logDet = diag.unaryExpr([](double x) { return log(x); }).sum();
} else {
logDet =
cg->R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
}
double logDet = cg->logDeterminant();
// Add the current clique's log-determinant to the overall sum
(*parentSum.logDet) += logDet;
return parentSum;

View File

@ -155,6 +155,20 @@ namespace gtsam {
}
}
/* ************************************************************************* */
double GaussianConditional::logDeterminant() const {
double logDet;
if (this->get_model()) {
Vector diag = this->R().diagonal();
this->get_model()->whitenInPlace(diag);
logDet = diag.unaryExpr([](double x) { return log(x); }).sum();
} else {
logDet =
this->R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
}
return logDet;
}
/* ************************************************************************* */
VectorValues GaussianConditional::solve(const VectorValues& x) const {
// Concatenate all vector values that correspond to parent variables

View File

@ -133,6 +133,28 @@ namespace gtsam {
/** Get a view of the r.h.s. vector d */
const constBVector d() const { return BaseFactor::getb(); }
/**
* @brief Compute the log determinant of the Gaussian conditional.
* The determinant is computed using the R matrix, which is upper
* triangular.
* For numerical stability, the determinant is computed in log
* form, so it is a summation rather than a multiplication.
*
* @return double
*/
double logDeterminant() const;
/**
* @brief Compute the determinant of the conditional from the
* upper-triangular R matrix.
*
* The determinant is computed in log form (hence summation) for numerical
* stability and then exponentiated.
*
* @return double
*/
double determinant() const { return exp(this->logDeterminant()); }
/**
* Solves a conditional Gaussian and writes the solution into the entries of
* \c x for each frontal variable of the conditional. The parents are