make GaussianMixtureFactor store the normalizing constant as well

release/4.3a0
Varun Agrawal 2022-12-30 19:28:42 +05:30
parent 6f8a23fe34
commit 22c758221d
4 changed files with 66 additions and 37 deletions

View File

@ -150,7 +150,8 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->likelihood(frontals);
return std::make_pair(conditional->likelihood(frontals),
0.5 * conditional->logDeterminant());
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);

View File

@ -29,8 +29,11 @@ namespace gtsam {
/* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}
const Mixture &factors)
: Base(continuousKeys, discreteKeys),
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
return std::make_pair(gf, 0.0);
}) {}
/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
@ -44,9 +47,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_,
[tol](const GaussianFactor::shared_ptr &f1,
const GaussianFactor::shared_ptr &f2) {
return f1->equals(*f2, tol);
[tol](const GaussianMixtureFactor::FactorAndLogZ &f1,
const GaussianMixtureFactor::FactorAndLogZ &f2) {
return f1.first->equals(*(f2.first), tol);
});
}
@ -60,7 +63,8 @@ void GaussianMixtureFactor::print(const std::string &s,
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
[&](const GaussianMixtureFactor::FactorAndLogZ &gf_z) -> std::string {
auto gf = gf_z.first;
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
@ -75,8 +79,10 @@ void GaussianMixtureFactor::print(const std::string &s,
}
/* *******************************************************************************/
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
return factors_;
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() {
// Unzip to tree of Gaussian factors and tree of log-constants,
// and return the first tree.
return unzip(factors_).first;
}
/* *******************************************************************************/
@ -95,9 +101,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
auto wrap = [](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
GaussianFactorGraph result;
result.push_back(factor);
result.push_back(factor_z.first);
return result;
};
return {factors_, wrap};
@ -108,8 +114,11 @@ AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const GaussianFactor::shared_ptr &factor) {
return factor->error(continuousValues);
[continuousValues](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
GaussianFactor::shared_ptr factor;
double log_z;
std::tie(factor, log_z) = factor_z;
return factor->error(continuousValues) + log_z;
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
@ -120,8 +129,10 @@ double GaussianMixtureFactor::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto factor = factors_(discreteValues);
return factor->error(continuousValues);
GaussianFactor::shared_ptr factor;
double log_z;
std::tie(factor, log_z) = factors_(discreteValues);
return factor->error(continuousValues) + log_z;
}
} // namespace gtsam

View File

@ -54,8 +54,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using Sum = DecisionTree<Key, GaussianFactorGraph>;
/// typedef for Decision Tree of Gaussian Factors
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
/// typedef of pair of Gaussian factor and log of normalizing constant.
using FactorAndLogZ = std::pair<GaussianFactor::shared_ptr, double>;
/// typedef for Decision Tree of Gaussian Factors and log-constant.
using Factors = DecisionTree<Key, FactorAndLogZ>;
using Mixture = DecisionTree<Key, GaussianFactor::shared_ptr>;
private:
/// Decision tree of Gaussian factors indexed by discrete keys.
@ -87,7 +90,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors);
const Mixture &factors);
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors_and_z)
: Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}
/**
* @brief Construct a new GaussianMixtureFactor object using a vector of
@ -101,7 +109,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors)
: GaussianMixtureFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {}
Mixture(discreteKeys, factors)) {}
/// @}
/// @name Testable
@ -115,7 +123,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
/// @}
/// Getter for the underlying Gaussian Factor Decision Tree.
const Factors &factors();
const Mixture factors();
/**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while

View File

@ -204,16 +204,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
using EliminationPair = GaussianFactorGraph::EliminationResult;
using EliminationPair =
std::pair<boost::shared_ptr<GaussianConditional>,
std::pair<boost::shared_ptr<GaussianFactor>, double>>;
KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
// This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianFactorGraph &graph)
-> GaussianFactorGraph::EliminationResult {
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
if (graph.empty()) {
return {nullptr, nullptr};
return {nullptr, std::make_pair(nullptr, 0.0)};
}
#ifdef HYBRID_TIMING
@ -222,17 +223,21 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>>
result = EliminatePreferCholesky(graph, frontalKeys);
conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
// Initialize the keysOfEliminated to be the keys of the
// eliminated GaussianConditional
keysOfEliminated = result.first->keys();
keysOfSeparator = result.second->keys();
keysOfEliminated = conditional_factor.first->keys();
keysOfSeparator = conditional_factor.second->keys();
#ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate);
#endif
std::pair<boost::shared_ptr<GaussianConditional>,
std::pair<boost::shared_ptr<GaussianFactor>, double>>
result = std::make_pair(conditional_factor.first,
std::make_pair(conditional_factor.second, 0.0));
return result;
};
@ -257,16 +262,20 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// DiscreteFactor, with the error for each discrete choice.
if (keysOfSeparator.empty()) {
VectorValues empty_values;
auto factorProb = [&](const GaussianFactor::shared_ptr &factor) {
if (!factor) {
return 0.0; // If nullptr, return 0.0 probability
} else {
// This is the probability q(μ) at the MLE point.
double error =
0.5 * std::abs(factor->augmentedInformation().determinant());
return std::exp(-error);
}
};
auto factorProb =
[&](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
if (!factor_z.first) {
return 0.0; // If nullptr, return 0.0 probability
} else {
GaussianFactor::shared_ptr factor = factor_z.first;
double log_z = factor_z.second;
// This is the probability q(μ) at the MLE point.
double error =
0.5 * std::abs(factor->augmentedInformation().determinant()) +
log_z;
return std::exp(-error);
}
};
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
auto discreteFactor =