Avoid using slow conditionals()
parent
6c9b25c45e
commit
1fe09f5e09
|
@ -197,11 +197,11 @@ void HybridGaussianConditional::print(const std::string &s,
|
|||
std::cout << std::endl
|
||||
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
||||
<< std::endl;
|
||||
conditionals().print(
|
||||
factors().print(
|
||||
"", [&](Key k) { return formatter(k); },
|
||||
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
|
||||
[&](const GaussianFactorValuePair &pair) -> std::string {
|
||||
RedirectCout rd;
|
||||
if (gf && !gf->empty()) {
|
||||
if (auto gf = std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
|
||||
gf->print("", formatter);
|
||||
return rd.str();
|
||||
} else {
|
||||
|
@ -249,12 +249,18 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
|||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||
const KeyVector continuousParentKeys = continuousParents();
|
||||
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
||||
conditionals(),
|
||||
[&](const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianFactorValuePair {
|
||||
factors(),
|
||||
[&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
|
||||
if (auto conditional =
|
||||
std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
|
||||
const auto likelihood_m = conditional->likelihood(given);
|
||||
const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
|
||||
return {likelihood_m, Cgm_Kgcm};
|
||||
// scalar is already correct.
|
||||
assert(pair.second ==
|
||||
conditional->negLogConstant() - negLogConstant_);
|
||||
return {likelihood_m, pair.second};
|
||||
} else {
|
||||
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||
}
|
||||
});
|
||||
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
|
||||
likelihoods);
|
||||
|
@ -283,15 +289,19 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
|||
|
||||
// Check the max value for every combination of our keys.
|
||||
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||
auto pruner = [&](const Assignment<Key> &choices,
|
||||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
||||
auto pruner =
|
||||
[&](const Assignment<Key> &choices,
|
||||
const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
|
||||
if (max->evaluate(choices) == 0.0)
|
||||
return {nullptr, std::numeric_limits<double>::infinity()};
|
||||
else
|
||||
return pair;
|
||||
};
|
||||
|
||||
auto pruned_conditionals = conditionals().apply(pruner);
|
||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||
pruned_conditionals);
|
||||
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
||||
return std::shared_ptr<HybridGaussianConditional>(
|
||||
new HybridGaussianConditional(discreteKeys(), nrFrontals_,
|
||||
prunedConditionals, negLogConstant_));
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -191,6 +191,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
const VectorValues &given) const;
|
||||
|
||||
/// Get Conditionals DecisionTree (dynamic cast from factors)
|
||||
/// @note Slow: avoid using in favor of factors(), which uses existing tree.
|
||||
const Conditionals conditionals() const;
|
||||
|
||||
/**
|
||||
|
@ -229,6 +230,14 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||
const Helper &helper);
|
||||
|
||||
/// Private constructor used when constants have already been calculated.
|
||||
HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals,
|
||||
const FactorValuePairs &factors,
|
||||
double negLogConstant)
|
||||
: BaseFactor(discreteKeys, factors),
|
||||
BaseConditional(nrFrontals),
|
||||
negLogConstant_(negLogConstant) {}
|
||||
|
||||
/// Check whether `given` has values for all frontal keys.
|
||||
bool allFrontalsGiven(const VectorValues &given) const;
|
||||
|
||||
|
|
Loading…
Reference in New Issue