Avoid using slow conditionals()

release/4.3a0
Frank Dellaert 2024-10-16 19:53:36 +09:00
parent 6c9b25c45e
commit 1fe09f5e09
2 changed files with 35 additions and 16 deletions

View File

@ -197,11 +197,11 @@ void HybridGaussianConditional::print(const std::string &s,
std::cout << std::endl std::cout << std::endl
<< " logNormalizationConstant: " << -negLogConstant() << std::endl << " logNormalizationConstant: " << -negLogConstant() << std::endl
<< std::endl; << std::endl;
conditionals().print( factors().print(
"", [&](Key k) { return formatter(k); }, "", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string { [&](const GaussianFactorValuePair &pair) -> std::string {
RedirectCout rd; RedirectCout rd;
if (gf && !gf->empty()) { if (auto gf = std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
gf->print("", formatter); gf->print("", formatter);
return rd.str(); return rd.str();
} else { } else {
@ -249,12 +249,18 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys(); const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents(); const KeyVector continuousParentKeys = continuousParents();
const HybridGaussianFactor::FactorValuePairs likelihoods( const HybridGaussianFactor::FactorValuePairs likelihoods(
conditionals(), factors(),
[&](const GaussianConditional::shared_ptr &conditional) [&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
-> GaussianFactorValuePair { if (auto conditional =
std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
const auto likelihood_m = conditional->likelihood(given); const auto likelihood_m = conditional->likelihood(given);
const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_; // scalar is already correct.
return {likelihood_m, Cgm_Kgcm}; assert(pair.second ==
conditional->negLogConstant() - negLogConstant_);
return {likelihood_m, pair.second};
} else {
return {nullptr, std::numeric_limits<double>::infinity()};
}
}); });
return std::make_shared<HybridGaussianFactor>(discreteParentKeys, return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
likelihoods); likelihoods);
@ -283,15 +289,19 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
// Check the max value for every combination of our keys. // Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional. // If the max value is 0.0, we can prune the corresponding conditional.
auto pruner = [&](const Assignment<Key> &choices, auto pruner =
const GaussianConditional::shared_ptr &conditional) [&](const Assignment<Key> &choices,
-> GaussianConditional::shared_ptr { const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
return (max->evaluate(choices) == 0.0) ? nullptr : conditional; if (max->evaluate(choices) == 0.0)
return {nullptr, std::numeric_limits<double>::infinity()};
else
return pair;
}; };
auto pruned_conditionals = conditionals().apply(pruner); FactorValuePairs prunedConditionals = factors().apply(pruner);
return std::make_shared<HybridGaussianConditional>(discreteKeys(), return std::shared_ptr<HybridGaussianConditional>(
pruned_conditionals); new HybridGaussianConditional(discreteKeys(), nrFrontals_,
prunedConditionals, negLogConstant_));
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -191,6 +191,7 @@ class GTSAM_EXPORT HybridGaussianConditional
const VectorValues &given) const; const VectorValues &given) const;
/// Get Conditionals DecisionTree (dynamic cast from factors) /// Get Conditionals DecisionTree (dynamic cast from factors)
/// @note Slow: avoid using in favor of factors(), which uses existing tree.
const Conditionals conditionals() const; const Conditionals conditionals() const;
/** /**
@ -229,6 +230,14 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional(const DiscreteKeys &discreteParents, HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Helper &helper); const Helper &helper);
/// Private constructor used when constants have already been calculated.
HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals,
const FactorValuePairs &factors,
double negLogConstant)
: BaseFactor(discreteKeys, factors),
BaseConditional(nrFrontals),
negLogConstant_(negLogConstant) {}
/// Check whether `given` has values for all frontal keys. /// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const; bool allFrontalsGiven(const VectorValues &given) const;