make GaussianMixtureFactor store the normalizing constant as well
parent
6f8a23fe34
commit
22c758221d
|
|
@ -150,7 +150,8 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
|||
const KeyVector continuousParentKeys = continuousParents();
|
||||
const GaussianMixtureFactor::Factors likelihoods(
|
||||
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
return conditional->likelihood(frontals);
|
||||
return std::make_pair(conditional->likelihood(frontals),
|
||||
0.5 * conditional->logDeterminant());
|
||||
});
|
||||
return boost::make_shared<GaussianMixtureFactor>(
|
||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||
|
|
|
|||
|
|
@ -29,8 +29,11 @@ namespace gtsam {
|
|||
/* *******************************************************************************/
|
||||
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const Factors &factors)
|
||||
: Base(continuousKeys, discreteKeys), factors_(factors) {}
|
||||
const Mixture &factors)
|
||||
: Base(continuousKeys, discreteKeys),
|
||||
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
|
||||
return std::make_pair(gf, 0.0);
|
||||
}) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
||||
|
|
@ -44,9 +47,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
|||
// Check the base and the factors:
|
||||
return Base::equals(*e, tol) &&
|
||||
factors_.equals(e->factors_,
|
||||
[tol](const GaussianFactor::shared_ptr &f1,
|
||||
const GaussianFactor::shared_ptr &f2) {
|
||||
return f1->equals(*f2, tol);
|
||||
[tol](const GaussianMixtureFactor::FactorAndLogZ &f1,
|
||||
const GaussianMixtureFactor::FactorAndLogZ &f2) {
|
||||
return f1.first->equals(*(f2.first), tol);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -60,7 +63,8 @@ void GaussianMixtureFactor::print(const std::string &s,
|
|||
} else {
|
||||
factors_.print(
|
||||
"", [&](Key k) { return formatter(k); },
|
||||
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
|
||||
[&](const GaussianMixtureFactor::FactorAndLogZ &gf_z) -> std::string {
|
||||
auto gf = gf_z.first;
|
||||
RedirectCout rd;
|
||||
std::cout << ":\n";
|
||||
if (gf && !gf->empty()) {
|
||||
|
|
@ -75,8 +79,10 @@ void GaussianMixtureFactor::print(const std::string &s,
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
|
||||
return factors_;
|
||||
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() {
|
||||
// Unzip to tree of Gaussian factors and tree of log-constants,
|
||||
// and return the first tree.
|
||||
return unzip(factors_).first;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
@ -95,9 +101,9 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
|||
/* *******************************************************************************/
|
||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||
const {
|
||||
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
|
||||
auto wrap = [](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor);
|
||||
result.push_back(factor_z.first);
|
||||
return result;
|
||||
};
|
||||
return {factors_, wrap};
|
||||
|
|
@ -108,8 +114,11 @@ AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
|||
const VectorValues &continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc =
|
||||
[continuousValues](const GaussianFactor::shared_ptr &factor) {
|
||||
return factor->error(continuousValues);
|
||||
[continuousValues](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
|
||||
GaussianFactor::shared_ptr factor;
|
||||
double log_z;
|
||||
std::tie(factor, log_z) = factor_z;
|
||||
return factor->error(continuousValues) + log_z;
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||
return errorTree;
|
||||
|
|
@ -120,8 +129,10 @@ double GaussianMixtureFactor::error(
|
|||
const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto factor = factors_(discreteValues);
|
||||
return factor->error(continuousValues);
|
||||
GaussianFactor::shared_ptr factor;
|
||||
double log_z;
|
||||
std::tie(factor, log_z) = factors_(discreteValues);
|
||||
return factor->error(continuousValues) + log_z;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -54,8 +54,11 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
|
||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||
|
||||
/// typedef for Decision Tree of Gaussian Factors
|
||||
using Factors = DecisionTree<Key, GaussianFactor::shared_ptr>;
|
||||
/// typedef of pair of Gaussian factor and log of normalizing constant.
|
||||
using FactorAndLogZ = std::pair<GaussianFactor::shared_ptr, double>;
|
||||
/// typedef for Decision Tree of Gaussian Factors and log-constant.
|
||||
using Factors = DecisionTree<Key, FactorAndLogZ>;
|
||||
using Mixture = DecisionTree<Key, GaussianFactor::shared_ptr>;
|
||||
|
||||
private:
|
||||
/// Decision tree of Gaussian factors indexed by discrete keys.
|
||||
|
|
@ -87,7 +90,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
*/
|
||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const Factors &factors);
|
||||
const Mixture &factors);
|
||||
|
||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const Factors &factors_and_z)
|
||||
: Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}
|
||||
|
||||
/**
|
||||
* @brief Construct a new GaussianMixtureFactor object using a vector of
|
||||
|
|
@ -101,7 +109,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
const DiscreteKeys &discreteKeys,
|
||||
const std::vector<GaussianFactor::shared_ptr> &factors)
|
||||
: GaussianMixtureFactor(continuousKeys, discreteKeys,
|
||||
Factors(discreteKeys, factors)) {}
|
||||
Mixture(discreteKeys, factors)) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
|
|
@ -115,7 +123,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
/// @}
|
||||
|
||||
/// Getter for the underlying Gaussian Factor Decision Tree.
|
||||
const Factors &factors();
|
||||
const Mixture factors();
|
||||
|
||||
/**
|
||||
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
||||
|
|
|
|||
|
|
@ -204,16 +204,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
};
|
||||
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
||||
|
||||
using EliminationPair = GaussianFactorGraph::EliminationResult;
|
||||
using EliminationPair =
|
||||
std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
std::pair<boost::shared_ptr<GaussianFactor>, double>>;
|
||||
|
||||
KeyVector keysOfEliminated; // Not the ordering
|
||||
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
||||
|
||||
// This is the elimination method on the leaf nodes
|
||||
auto eliminate = [&](const GaussianFactorGraph &graph)
|
||||
-> GaussianFactorGraph::EliminationResult {
|
||||
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
|
||||
if (graph.empty()) {
|
||||
return {nullptr, nullptr};
|
||||
return {nullptr, std::make_pair(nullptr, 0.0)};
|
||||
}
|
||||
|
||||
#ifdef HYBRID_TIMING
|
||||
|
|
@ -222,17 +223,21 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
|
||||
std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
boost::shared_ptr<GaussianFactor>>
|
||||
result = EliminatePreferCholesky(graph, frontalKeys);
|
||||
conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
|
||||
|
||||
// Initialize the keysOfEliminated to be the keys of the
|
||||
// eliminated GaussianConditional
|
||||
keysOfEliminated = result.first->keys();
|
||||
keysOfSeparator = result.second->keys();
|
||||
keysOfEliminated = conditional_factor.first->keys();
|
||||
keysOfSeparator = conditional_factor.second->keys();
|
||||
|
||||
#ifdef HYBRID_TIMING
|
||||
gttoc_(hybrid_eliminate);
|
||||
#endif
|
||||
|
||||
std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
std::pair<boost::shared_ptr<GaussianFactor>, double>>
|
||||
result = std::make_pair(conditional_factor.first,
|
||||
std::make_pair(conditional_factor.second, 0.0));
|
||||
return result;
|
||||
};
|
||||
|
||||
|
|
@ -257,16 +262,20 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
// DiscreteFactor, with the error for each discrete choice.
|
||||
if (keysOfSeparator.empty()) {
|
||||
VectorValues empty_values;
|
||||
auto factorProb = [&](const GaussianFactor::shared_ptr &factor) {
|
||||
if (!factor) {
|
||||
return 0.0; // If nullptr, return 0.0 probability
|
||||
} else {
|
||||
// This is the probability q(μ) at the MLE point.
|
||||
double error =
|
||||
0.5 * std::abs(factor->augmentedInformation().determinant());
|
||||
return std::exp(-error);
|
||||
}
|
||||
};
|
||||
auto factorProb =
|
||||
[&](const GaussianMixtureFactor::FactorAndLogZ &factor_z) {
|
||||
if (!factor_z.first) {
|
||||
return 0.0; // If nullptr, return 0.0 probability
|
||||
} else {
|
||||
GaussianFactor::shared_ptr factor = factor_z.first;
|
||||
double log_z = factor_z.second;
|
||||
// This is the probability q(μ) at the MLE point.
|
||||
double error =
|
||||
0.5 * std::abs(factor->augmentedInformation().determinant()) +
|
||||
log_z;
|
||||
return std::exp(-error);
|
||||
}
|
||||
};
|
||||
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
|
||||
|
||||
auto discreteFactor =
|
||||
|
|
|
|||
Loading…
Reference in New Issue