remove augment method in favor of conversion

release/4.3a0
Varun Agrawal 2024-08-22 14:43:23 -04:00
parent 98a2668fb1
commit cd9ee08457
3 changed files with 23 additions and 78 deletions

View File

@ -200,27 +200,24 @@ std::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const GaussianMixtureFactor::Factors likelihoods( const GaussianMixtureFactor::Factors likelihoods(
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) { conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
const auto likelihood_m = conditional->likelihood(given); const auto likelihood_m = conditional->likelihood(given);
return likelihood_m; const double Cgm_Kgcm =
logConstant_ - conditional->logNormalizationConstant();
if (Cgm_Kgcm == 0.0) {
return likelihood_m;
} else {
// Add a constant factor to the likelihood in case the noise models
// are not all equal.
GaussianFactorGraph gfg;
gfg.push_back(likelihood_m);
Vector c(1);
c << std::sqrt(2.0 * Cgm_Kgcm);
auto constantFactor = std::make_shared<JacobianFactor>(c);
gfg.push_back(constantFactor);
return std::make_shared<JacobianFactor>(gfg);
}
}); });
// First compute all the sqrt(|2 pi Sigma|) terms
auto computeLogNormalizers = [](const GaussianFactor::shared_ptr &gf) {
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
// If we have, say, a Hessian factor, then no need to do anything
if (!jf) return 0.0;
auto model = jf->get_model();
// If there is no noise model, there is nothing to do.
if (!model) {
return 0.0;
}
return ComputeLogNormalizer(model);
};
AlgebraicDecisionTree<Key> log_normalizers =
DecisionTree<Key, double>(likelihoods, computeLogNormalizers);
return std::make_shared<GaussianMixtureFactor>( return std::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods, log_normalizers); continuousParentKeys, discreteParentKeys, likelihoods);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -28,55 +28,11 @@
namespace gtsam { namespace gtsam {
/**
* @brief Helper function to augment the [A|b] matrices in the factor components
* with the normalizer values.
* This is done by storing the normalizer value in
* the `b` vector as an additional row.
*
* @param factors DecisionTree of GaussianFactor shared pointers.
* @param logNormalizers Tree of log-normalizers corresponding to each
* Gaussian factor in factors.
* @return GaussianMixtureFactor::Factors
*/
GaussianMixtureFactor::Factors augment(
const GaussianMixtureFactor::Factors &factors,
const AlgebraicDecisionTree<Key> &logNormalizers) {
// Find the minimum value so we can "proselytize" to positive values.
// Done because we can't have sqrt of negative numbers.
double min_log_normalizer = logNormalizers.min();
AlgebraicDecisionTree<Key> log_normalizers = logNormalizers.apply(
[&min_log_normalizer](double n) { return n - min_log_normalizer; });
// Finally, update the [A|b] matrices.
auto update = [&log_normalizers](
const Assignment<Key> &assignment,
const GaussianMixtureFactor::sharedFactor &gf) {
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
if (!jf) return gf;
// If the log_normalizer is 0, do nothing
if (log_normalizers(assignment) == 0.0) return gf;
GaussianFactorGraph gfg;
gfg.push_back(jf);
Vector c(1);
c << std::sqrt(log_normalizers(assignment));
auto constantFactor = std::make_shared<JacobianFactor>(c);
gfg.push_back(constantFactor);
return std::dynamic_pointer_cast<GaussianFactor>(
std::make_shared<JacobianFactor>(gfg));
};
return factors.apply(update);
}
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor( GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const Factors &factors, const AlgebraicDecisionTree<Key> &logNormalizers) const Factors &factors)
: Base(continuousKeys, discreteKeys), : Base(continuousKeys, discreteKeys), factors_(factors) {}
factors_(augment(factors, logNormalizers)) {}
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {

View File

@ -82,14 +82,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* their cardinalities. * their cardinalities.
* @param factors The decision tree of Gaussian factors stored as the mixture * @param factors The decision tree of Gaussian factors stored as the mixture
* density. * density.
* @param logNormalizers Tree of log-normalizers corresponding to each
* Gaussian factor in factors.
*/ */
GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const Factors &factors, const Factors &factors);
const AlgebraicDecisionTree<Key> &logNormalizers =
AlgebraicDecisionTree<Key>(0.0));
/** /**
* @brief Construct a new GaussianMixtureFactor object using a vector of * @brief Construct a new GaussianMixtureFactor object using a vector of
@ -98,16 +94,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @param continuousKeys Vector of keys for continuous factors. * @param continuousKeys Vector of keys for continuous factors.
* @param discreteKeys Vector of discrete keys. * @param discreteKeys Vector of discrete keys.
* @param factors Vector of gaussian factor shared pointers. * @param factors Vector of gaussian factor shared pointers.
* @param logNormalizers Tree of log-normalizers corresponding to each
* Gaussian factor in factors.
*/ */
GaussianMixtureFactor(const KeyVector &continuousKeys, GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const std::vector<sharedFactor> &factors, const std::vector<sharedFactor> &factors)
const AlgebraicDecisionTree<Key> &logNormalizers =
AlgebraicDecisionTree<Key>(0.0))
: GaussianMixtureFactor(continuousKeys, discreteKeys, : GaussianMixtureFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors), logNormalizers) {} Factors(discreteKeys, factors)) {}
/// @} /// @}
/// @name Testable /// @name Testable