Merge pull request #1327 from borglab/fix/gbt-determinant
						commit
						a281e1a26e
					
				|  | @ -31,18 +31,37 @@ namespace gtsam { | |||
|   template class BayesTreeCliqueBase<GaussianBayesTreeClique, GaussianFactorGraph>; | ||||
|   template class BayesTree<GaussianBayesTreeClique>; | ||||
| 
 | ||||
|   /* ************************************************************************* */ | ||||
|   namespace internal | ||||
|   { | ||||
|     /* ************************************************************************* */ | ||||
|   double logDeterminant(const GaussianBayesTreeClique::shared_ptr& clique, | ||||
|                         double& parentSum) { | ||||
|     parentSum += clique->conditional() | ||||
|                      ->R() | ||||
|                      .diagonal() | ||||
|                      .unaryExpr([](double x) { return log(x); }) | ||||
|                      .sum(); | ||||
|     return 0; | ||||
|   /* ************************************************************************ */ | ||||
|   namespace internal { | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Struct to help with traversing the Bayes Tree | ||||
|    * for log-determinant computation. | ||||
|    * Records the data which is passed to the child nodes in pre-order visit. | ||||
|    */ | ||||
|   struct LogDeterminantData { | ||||
|     // Use pointer so we can get the full result after tree traversal
 | ||||
|     double* logDet; | ||||
|     LogDeterminantData(double* logDet) | ||||
|         : logDet(logDet) {} | ||||
|   }; | ||||
|   /* ************************************************************************ */ | ||||
|   LogDeterminantData& logDeterminant( | ||||
|       const GaussianBayesTreeClique::shared_ptr& clique, | ||||
|       LogDeterminantData& parentSum) { | ||||
|     auto cg = clique->conditional(); | ||||
|     double logDet; | ||||
|     if (cg->get_model()) { | ||||
|       Vector diag = cg->R().diagonal(); | ||||
|       cg->get_model()->whitenInPlace(diag); | ||||
|       logDet = diag.unaryExpr([](double x) { return log(x); }).sum(); | ||||
|     } else { | ||||
|       logDet = | ||||
|           cg->R().diagonal().unaryExpr([](double x) { return log(x); }).sum(); | ||||
|     } | ||||
|     // Add the current clique's log-determinant to the overall sum
 | ||||
|     (*parentSum.logDet) += logDet; | ||||
|     return parentSum; | ||||
|   } | ||||
|   }  // namespace internal
 | ||||
| 
 | ||||
|  | @ -87,7 +106,14 @@ namespace gtsam { | |||
|       return 0.0; | ||||
|     } else { | ||||
|       double sum = 0.0; | ||||
|       treeTraversal::DepthFirstForest(*this, sum, internal::logDeterminant); | ||||
|       // Store the log-determinant in this struct.
 | ||||
|       internal::LogDeterminantData rootData(&sum); | ||||
|       // No need to do anything for post-operation.
 | ||||
|       treeTraversal::no_op visitorPost; | ||||
|       // Limits OpenMP threads if we're mixing TBB and OpenMP
 | ||||
|       TbbOpenMPMixedScope threadLimiter; | ||||
|       // Traverse the GaussianBayesTree depth first and call logDeterminant on each node.
 | ||||
|       treeTraversal::DepthFirstForestParallel(*this, rootData, internal::logDeterminant, visitorPost); | ||||
|       return sum; | ||||
|     } | ||||
|   } | ||||
|  | @ -106,7 +132,3 @@ namespace gtsam { | |||
| 
 | ||||
| 
 | ||||
| } // \namespace gtsam
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,18 +15,18 @@ | |||
|  * @author Kai Ni | ||||
|  */ | ||||
| 
 | ||||
| #include <iostream> | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
| 
 | ||||
| #include <boost/assign/list_of.hpp> | ||||
| #include <boost/assign/std/list.hpp> // for operator += | ||||
| #include <boost/assign/std/set.hpp> // for operator += | ||||
| 
 | ||||
| #include <gtsam/base/debug.h> | ||||
| #include <gtsam/base/numericalDerivative.h> | ||||
| #include <gtsam/linear/GaussianJunctionTree.h> | ||||
| #include <gtsam/inference/Symbol.h> | ||||
| #include <gtsam/linear/GaussianBayesTree.h> | ||||
| #include <gtsam/linear/GaussianConditional.h> | ||||
| #include <gtsam/linear/GaussianJunctionTree.h> | ||||
| 
 | ||||
| #include <boost/assign/list_of.hpp> | ||||
| #include <boost/assign/std/list.hpp>  // for operator += | ||||
| #include <boost/assign/std/set.hpp>   // for operator += | ||||
| #include <iostream> | ||||
| 
 | ||||
| using namespace boost::assign; | ||||
| using namespace std::placeholders; | ||||
|  | @ -321,6 +321,35 @@ TEST(GaussianBayesTree, determinant_and_smallestEigenvalue) { | |||
|   EXPECT_DOUBLES_EQUAL(expectedDeterminant,actualDeterminant,expectedDeterminant*1e-6);// relative tolerance
 | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| /// Test to expose bug in GaussianBayesTree::logDeterminant.
 | ||||
| TEST(GaussianBayesTree, LogDeterminant) { | ||||
|   using symbol_shorthand::L; | ||||
|   using symbol_shorthand::X; | ||||
| 
 | ||||
|   // Create a factor graph that will result in
 | ||||
|   // a bayes tree with at least 2 nodes.
 | ||||
|   GaussianFactorGraph fg; | ||||
|   Key x1 = X(1), x2 = X(2), l1 = L(1); | ||||
|   SharedDiagonal unit2 = noiseModel::Unit::Create(2); | ||||
|   fg += JacobianFactor(x1, 10 * I_2x2, -1.0 * Vector2::Ones(), unit2); | ||||
|   fg += JacobianFactor(x2, 10 * I_2x2, x1, -10 * I_2x2, Vector2(2.0, -1.0), | ||||
|                        unit2); | ||||
|   fg += JacobianFactor(l1, 5 * I_2x2, x1, -5 * I_2x2, Vector2(0.0, 1.0), unit2); | ||||
|   fg += | ||||
|       JacobianFactor(x2, -5 * I_2x2, l1, 5 * I_2x2, Vector2(-1.0, 1.5), unit2); | ||||
|   fg += JacobianFactor(x3, 10 * I_2x2, x2, -10 * I_2x2, Vector2(2.0, -1.0), | ||||
|                        unit2); | ||||
|   fg += JacobianFactor(x3, 10 * I_2x2, -1.0 * Vector2::Ones(), unit2); | ||||
| 
 | ||||
|   // create corresponding Bayes net and Bayes tree:
 | ||||
|   boost::shared_ptr<gtsam::GaussianBayesNet> bn = fg.eliminateSequential(); | ||||
|   boost::shared_ptr<gtsam::GaussianBayesTree> bt = fg.eliminateMultifrontal(); | ||||
| 
 | ||||
|   // Test logDeterminant
 | ||||
|   EXPECT_DOUBLES_EQUAL(bn->logDeterminant(), bt->logDeterminant(), 1e-9); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { TestResult tr; return TestRegistry::runAllTests(tr);} | ||||
| /* ************************************************************************* */ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue