Avoid using slow conditionals()
parent
6c9b25c45e
commit
1fe09f5e09
|
@ -197,11 +197,11 @@ void HybridGaussianConditional::print(const std::string &s,
|
||||||
std::cout << std::endl
|
std::cout << std::endl
|
||||||
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
conditionals().print(
|
factors().print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"", [&](Key k) { return formatter(k); },
|
||||||
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
|
[&](const GaussianFactorValuePair &pair) -> std::string {
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
if (gf && !gf->empty()) {
|
if (auto gf = std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
|
||||||
gf->print("", formatter);
|
gf->print("", formatter);
|
||||||
return rd.str();
|
return rd.str();
|
||||||
} else {
|
} else {
|
||||||
|
@ -249,12 +249,18 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
||||||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||||
const KeyVector continuousParentKeys = continuousParents();
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
||||||
conditionals(),
|
factors(),
|
||||||
[&](const GaussianConditional::shared_ptr &conditional)
|
[&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
|
||||||
-> GaussianFactorValuePair {
|
if (auto conditional =
|
||||||
const auto likelihood_m = conditional->likelihood(given);
|
std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
|
||||||
const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
|
const auto likelihood_m = conditional->likelihood(given);
|
||||||
return {likelihood_m, Cgm_Kgcm};
|
// scalar is already correct.
|
||||||
|
assert(pair.second ==
|
||||||
|
conditional->negLogConstant() - negLogConstant_);
|
||||||
|
return {likelihood_m, pair.second};
|
||||||
|
} else {
|
||||||
|
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
|
}
|
||||||
});
|
});
|
||||||
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
|
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
|
||||||
likelihoods);
|
likelihoods);
|
||||||
|
@ -283,15 +289,19 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
|
|
||||||
// Check the max value for every combination of our keys.
|
// Check the max value for every combination of our keys.
|
||||||
// If the max value is 0.0, we can prune the corresponding conditional.
|
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||||
auto pruner = [&](const Assignment<Key> &choices,
|
auto pruner =
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
[&](const Assignment<Key> &choices,
|
||||||
-> GaussianConditional::shared_ptr {
|
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
|
||||||
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
if (max->evaluate(choices) == 0.0)
|
||||||
|
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
|
else
|
||||||
|
return pair;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto pruned_conditionals = conditionals().apply(pruner);
|
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
||||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
return std::shared_ptr<HybridGaussianConditional>(
|
||||||
pruned_conditionals);
|
new HybridGaussianConditional(discreteKeys(), nrFrontals_,
|
||||||
|
prunedConditionals, negLogConstant_));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
@ -191,6 +191,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
const VectorValues &given) const;
|
const VectorValues &given) const;
|
||||||
|
|
||||||
/// Get Conditionals DecisionTree (dynamic cast from factors)
|
/// Get Conditionals DecisionTree (dynamic cast from factors)
|
||||||
|
/// @note Slow: avoid using in favor of factors(), which uses existing tree.
|
||||||
const Conditionals conditionals() const;
|
const Conditionals conditionals() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -229,6 +230,14 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||||
const Helper &helper);
|
const Helper &helper);
|
||||||
|
|
||||||
|
/// Private constructor used when constants have already been calculated.
|
||||||
|
HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals,
|
||||||
|
const FactorValuePairs &factors,
|
||||||
|
double negLogConstant)
|
||||||
|
: BaseFactor(discreteKeys, factors),
|
||||||
|
BaseConditional(nrFrontals),
|
||||||
|
negLogConstant_(negLogConstant) {}
|
||||||
|
|
||||||
/// Check whether `given` has values for all frontal keys.
|
/// Check whether `given` has values for all frontal keys.
|
||||||
bool allFrontalsGiven(const VectorValues &given) const;
|
bool allFrontalsGiven(const VectorValues &given) const;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue