support for varying normalizers in GaussianMixtureFactor

release/4.3a0
Varun Agrawal 2024-08-20 16:29:40 -04:00
parent eef9765e4a
commit ea104c4b83
2 changed files with 92 additions and 9 deletions

View File

@ -28,11 +28,86 @@
namespace gtsam { namespace gtsam {
/**
* @brief Helper function to correct the [A|b] matrices in the factor components
* with the normalizer values.
* This is done by storing the normalizer value in
* the `b` vector as an additional row.
*
* @param factors DecisionTree of GaussianFactor shared pointers.
* @param varyingNormalizers Flag indicating the normalizers are different for
* each component.
* @return GaussianMixtureFactor::Factors
*/
GaussianMixtureFactor::Factors correct(
const GaussianMixtureFactor::Factors &factors, bool varyingNormalizers) {
if (!varyingNormalizers) {
return factors;
}
// First compute all the sqrt(|2 pi Sigma|) terms
auto computeNormalizers = [](const GaussianMixtureFactor::sharedFactor &gf) {
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
// If we have, say, a Hessian factor, then no need to do anything
if (!jf) return 0.0;
auto model = jf->get_model();
// If there is no noise model, there is nothing to do.
if (!model) {
return 0.0;
}
// Since noise models are Gaussian, we can get the logDeterminant using the
// same trick as in GaussianConditional
double logDetR =
model->R().diagonal().unaryExpr([](double x) { return log(x); }).sum();
double logDeterminantSigma = -2.0 * logDetR;
size_t n = model->dim();
constexpr double log2pi = 1.8378770664093454835606594728112;
return n * log2pi + logDeterminantSigma;
};
AlgebraicDecisionTree<Key> log_normalizers =
DecisionTree<Key, double>(factors, computeNormalizers);
// Find the minimum value so we can "proselytize" to positive values.
// Done because we can't have sqrt of negative numbers.
double min_log_normalizer = log_normalizers.min();
log_normalizers = log_normalizers.apply(
[&min_log_normalizer](double n) { return n - min_log_normalizer; });
// Finally, update the [A|b] matrices.
auto update = [&log_normalizers](
const Assignment<Key> &assignment,
const GaussianMixtureFactor::sharedFactor &gf) {
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
if (!jf) return gf;
// If there is no noise model, there is nothing to do.
if (!jf->get_model()) return gf;
// If the log_normalizer is 0, do nothing
if (log_normalizers(assignment) == 0.0) return gf;
GaussianFactorGraph gfg;
gfg.push_back(jf);
Vector c(1);
c << std::sqrt(log_normalizers(assignment));
auto constantFactor = std::make_shared<JacobianFactor>(c);
gfg.push_back(constantFactor);
return std::dynamic_pointer_cast<GaussianFactor>(
std::make_shared<JacobianFactor>(gfg));
};
return factors.apply(update);
}
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const Factors &factors) const Factors &factors,
: Base(continuousKeys, discreteKeys), factors_(factors) {} bool varyingNormalizers)
: Base(continuousKeys, discreteKeys),
factors_(correct(factors, varyingNormalizers)) {}
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
@ -54,7 +129,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
/* *******************************************************************************/ /* *******************************************************************************/
void GaussianMixtureFactor::print(const std::string &s, void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter); std::cout << (s.empty() ? "" : s + "\n");
std::cout << "GaussianMixtureFactor" << std::endl;
HybridFactor::print("", formatter);
std::cout << "{\n"; std::cout << "{\n";
if (factors_.empty()) { if (factors_.empty()) {
std::cout << " empty" << std::endl; std::cout << " empty" << std::endl;

View File

@ -82,10 +82,13 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* their cardinalities. * their cardinalities.
* @param factors The decision tree of Gaussian factors stored as the mixture * @param factors The decision tree of Gaussian factors stored as the mixture
* density. * density.
* @param varyingNormalizers Flag indicating factor components have varying
* normalizer values.
*/ */
GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const Factors &factors); const Factors &factors,
bool varyingNormalizers = false);
/** /**
* @brief Construct a new GaussianMixtureFactor object using a vector of * @brief Construct a new GaussianMixtureFactor object using a vector of
@ -94,12 +97,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @param continuousKeys Vector of keys for continuous factors. * @param continuousKeys Vector of keys for continuous factors.
* @param discreteKeys Vector of discrete keys. * @param discreteKeys Vector of discrete keys.
* @param factors Vector of gaussian factor shared pointers. * @param factors Vector of gaussian factor shared pointers.
* @param varyingNormalizers Flag indicating factor components have varying
* normalizer values.
*/ */
GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const std::vector<sharedFactor> &factors) const std::vector<sharedFactor> &factors,
bool varyingNormalizers = false)
: GaussianMixtureFactor(continuousKeys, discreteKeys, : GaussianMixtureFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {} Factors(discreteKeys, factors),
varyingNormalizers) {}
/// @} /// @}
/// @name Testable /// @name Testable
@ -107,9 +114,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
bool equals(const HybridFactor &lf, double tol = 1e-9) const override; bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
void print( void print(const std::string &s = "", const KeyFormatter &formatter =
const std::string &s = "GaussianMixtureFactor\n", DefaultKeyFormatter) const override;
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard API /// @name Standard API