attempt to fix elimination

release/4.3a0
Frank Dellaert 2023-01-01 17:02:23 -05:00
parent 3d821ec22b
commit 0095f73130
2 changed files with 24 additions and 40 deletions

View File

@ -62,6 +62,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
// Note: constant is log of normalization constant for probabilities. // Note: constant is log of normalization constant for probabilities.
// Errors is the negative log-likelihood, // Errors is the negative log-likelihood,
// hence we subtract the constant here. // hence we subtract the constant here.
if (!factor) return 0.0; // If nullptr, return 0.0 error
return factor->error(values) - constant; return factor->error(values) - constant;
} }

View File

@ -199,14 +199,14 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end()); discreteSeparatorSet.end());
// Collect all the frontal factors to create Gaussian factor graphs // Collect all the factors to create a set of Gaussian factor graphs in a
// indexed on the discrete keys. // decision tree indexed by all discrete keys involved.
GaussianMixtureFactor::Sum sum = factors.SumFrontals(); GaussianMixtureFactor::Sum sum = factors.SumFrontals();
// If a tree leaf contains nullptr, // If a tree leaf contains nullptr, convert that leaf to an empty
// convert that leaf to an empty GaussianFactorGraph. // GaussianFactorGraph. Needed since the DecisionTree will otherwise create a
// Needed since the DecisionTree will otherwise create // GFG with a single (null) factor.
// a GFG with a single (null) factor. // TODO(dellaert): can SumFrontals not guarantee this?
auto emptyGaussian = auto emptyGaussian =
[](const GaussianMixtureFactor::GraphAndConstant &graph_z) { [](const GaussianMixtureFactor::GraphAndConstant &graph_z) {
bool hasNull = std::any_of( bool hasNull = std::any_of(
@ -222,7 +222,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>, using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>; GaussianMixtureFactor::FactorAndConstant>;
KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; KeyVector keysOfSeparator;
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
@ -236,24 +235,21 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
gttic_(hybrid_eliminate); gttic_(hybrid_eliminate);
#endif #endif
std::pair<boost::shared_ptr<GaussianConditional>, boost::shared_ptr<GaussianConditional> conditional;
boost::shared_ptr<GaussianFactor>> boost::shared_ptr<GaussianFactor> newFactor;
conditional_factor = boost::tie(conditional, newFactor) =
EliminatePreferCholesky(graph_z.graph, frontalKeys); EliminatePreferCholesky(graph_z.graph, frontalKeys);
// Initialize the keysOfEliminated to be the keys of the // TODO(dellaert): always the same, and we already computed this in caller?
// eliminated GaussianConditional keysOfSeparator = newFactor->keys();
keysOfEliminated = conditional_factor.first->keys();
keysOfSeparator = conditional_factor.second->keys();
#ifdef HYBRID_TIMING #ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate); gttoc_(hybrid_eliminate);
#endif #endif
GaussianConditional::shared_ptr conditional = conditional_factor.first;
// Get the log of the log normalization constant inverse. // Get the log of the log normalization constant inverse.
double logZ = -conditional->logNormalizationConstant() + graph_z.constant; double logZ = -conditional->logNormalizationConstant() + graph_z.constant;
return {conditional, {conditional_factor.second, logZ}}; return {conditional, {newFactor, logZ}};
}; };
// Perform elimination! // Perform elimination!
@ -266,44 +262,31 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Separate out decision tree into conditionals and remaining factors. // Separate out decision tree into conditionals and remaining factors.
auto pair = unzip(eliminationResults); auto pair = unzip(eliminationResults);
const auto &separatorFactors = pair.second;
// Create the GaussianMixture from the conditionals // Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>( auto conditional = boost::make_shared<GaussianMixture>(
frontalKeys, keysOfSeparator, discreteSeparator, pair.first); frontalKeys, keysOfSeparator, discreteSeparator, pair.first);
// If there are no more continuous parents, then we should create here a // If there are no more continuous parents, then we should create a
// DiscreteFactor, with the error for each discrete choice. // DiscreteFactor here, with the error for each discrete choice.
const auto &separatorFactors = pair.second;
if (keysOfSeparator.empty()) { if (keysOfSeparator.empty()) {
auto factorProb = auto factorProb =
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
GaussianFactor::shared_ptr factor = factor_z.factor;
if (!factor) {
return 0.0; // If nullptr, return 0.0 probability
} else {
// This is the probability q(μ) at the MLE point. // This is the probability q(μ) at the MLE point.
double error = factor_z.error(VectorValues()); return factor_z.error(VectorValues());
return std::exp(-error);
}
}; };
DecisionTree<Key, double> fdt(separatorFactors, factorProb); const DecisionTree<Key, double> fdt(separatorFactors, factorProb);
// Normalize the values of decision tree to be valid probabilities const auto discreteFactor =
double sum = 0.0;
auto visitor = [&](double y) { sum += y; };
fdt.visit(visitor);
// fdt = DecisionTree<Key, double>(fdt,
// [sum](const double &x) { return x / sum;
// });
auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt); boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
return {boost::make_shared<HybridConditional>(conditional), return {boost::make_shared<HybridConditional>(conditional),
boost::make_shared<HybridDiscreteFactor>(discreteFactor)}; boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
} else { } else {
// Create a resulting GaussianMixtureFactor on the separator. // Create a resulting GaussianMixtureFactor on the separator.
auto factor = boost::make_shared<GaussianMixtureFactor>( // Keys can be computed from the factors, so we should not need to pass them
// in.
const auto factor = boost::make_shared<GaussianMixtureFactor>(
KeyVector(continuousSeparator.begin(), continuousSeparator.end()), KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
discreteSeparator, separatorFactors); discreteSeparator, separatorFactors);
return {boost::make_shared<HybridConditional>(conditional), factor}; return {boost::make_shared<HybridConditional>(conditional), factor};