Avoid using slow conditionals()

release/4.3a0
Frank Dellaert 2024-10-16 19:53:36 +09:00
parent 6c9b25c45e
commit 1fe09f5e09
2 changed files with 35 additions and 16 deletions

View File

@ -197,11 +197,11 @@ void HybridGaussianConditional::print(const std::string &s,
std::cout << std::endl
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
<< std::endl;
conditionals().print(
factors().print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
[&](const GaussianFactorValuePair &pair) -> std::string {
RedirectCout rd;
if (gf && !gf->empty()) {
if (auto gf = std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
gf->print("", formatter);
return rd.str();
} else {
@ -249,12 +249,18 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const HybridGaussianFactor::FactorValuePairs likelihoods(
conditionals(),
[&](const GaussianConditional::shared_ptr &conditional)
-> GaussianFactorValuePair {
const auto likelihood_m = conditional->likelihood(given);
const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
return {likelihood_m, Cgm_Kgcm};
factors(),
[&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
if (auto conditional =
std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
const auto likelihood_m = conditional->likelihood(given);
// scalar is already correct.
assert(pair.second ==
conditional->negLogConstant() - negLogConstant_);
return {likelihood_m, pair.second};
} else {
return {nullptr, std::numeric_limits<double>::infinity()};
}
});
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
likelihoods);
@ -283,15 +289,19 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
// Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional.
auto pruner = [&](const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
auto pruner =
[&](const Assignment<Key> &choices,
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
if (max->evaluate(choices) == 0.0)
return {nullptr, std::numeric_limits<double>::infinity()};
else
return pair;
};
auto pruned_conditionals = conditionals().apply(pruner);
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
pruned_conditionals);
FactorValuePairs prunedConditionals = factors().apply(pruner);
return std::shared_ptr<HybridGaussianConditional>(
new HybridGaussianConditional(discreteKeys(), nrFrontals_,
prunedConditionals, negLogConstant_));
}
/* *******************************************************************************/

View File

@ -191,6 +191,7 @@ class GTSAM_EXPORT HybridGaussianConditional
const VectorValues &given) const;
/// Get Conditionals DecisionTree (dynamic cast from factors)
/// @note Slow: avoid using in favor of factors(), which uses existing tree.
const Conditionals conditionals() const;
/**
@ -229,6 +230,14 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Helper &helper);
/// Private constructor used when constants have already been calculated.
HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals,
const FactorValuePairs &factors,
double negLogConstant)
: BaseFactor(discreteKeys, factors),
BaseConditional(nrFrontals),
negLogConstant_(negLogConstant) {}
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;