helper method to reduce code duplication
parent
4c74ec113a
commit
cbb0a30173
|
@ -33,6 +33,19 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
GaussianConditional::shared_ptr checkConditional(
|
||||||
|
const GaussianFactor::shared_ptr &factor) {
|
||||||
|
if (auto conditional =
|
||||||
|
std::dynamic_pointer_cast<GaussianConditional>(factor)) {
|
||||||
|
return conditional;
|
||||||
|
} else {
|
||||||
|
throw std::logic_error(
|
||||||
|
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
/**
|
/**
|
||||||
* @brief Helper struct for constructing HybridGaussianConditional objects
|
* @brief Helper struct for constructing HybridGaussianConditional objects
|
||||||
|
@ -92,10 +105,7 @@ struct HybridGaussianConditional::Helper {
|
||||||
explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
|
explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
|
||||||
auto func = [this](const GaussianFactorValuePair &pair) {
|
auto func = [this](const GaussianFactorValuePair &pair) {
|
||||||
if (!pair.first) return;
|
if (!pair.first) return;
|
||||||
auto gc = std::dynamic_pointer_cast<GaussianConditional>(pair.first);
|
auto gc = checkConditional(pair.first);
|
||||||
if (!gc)
|
|
||||||
throw std::runtime_error(
|
|
||||||
"HybridGaussianConditional called with non-conditional.");
|
|
||||||
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||||
minNegLogConstant = std::min(minNegLogConstant, pair.second);
|
minNegLogConstant = std::min(minNegLogConstant, pair.second);
|
||||||
};
|
};
|
||||||
|
@ -179,14 +189,11 @@ size_t HybridGaussianConditional::nrComponents() const {
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
||||||
const DiscreteValues &discreteValues) const {
|
const DiscreteValues &discreteValues) const {
|
||||||
auto &[ptr, _] = factors()(discreteValues);
|
auto &[factor, _] = factors()(discreteValues);
|
||||||
if (!ptr) return nullptr;
|
if (!factor) return nullptr;
|
||||||
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
|
|
||||||
if (conditional)
|
auto conditional = checkConditional(factor);
|
||||||
return conditional;
|
return conditional;
|
||||||
else
|
|
||||||
throw std::logic_error(
|
|
||||||
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
@ -320,28 +327,23 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
};
|
};
|
||||||
|
|
||||||
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
FactorValuePairs prunedConditionals = factors().apply(pruner);
|
||||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(), prunedConditionals);
|
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
||||||
|
prunedConditionals);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianConditional::logProbability(
|
double HybridGaussianConditional::logProbability(
|
||||||
const HybridValues &values) const {
|
const HybridValues &values) const {
|
||||||
auto [factor, _] = factors()(values.discrete());
|
auto [factor, _] = factors()(values.discrete());
|
||||||
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
|
auto conditional = checkConditional(factor);
|
||||||
return conditional->logProbability(values.continuous());
|
return conditional->logProbability(values.continuous());
|
||||||
else
|
|
||||||
throw std::logic_error(
|
|
||||||
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
|
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
|
||||||
auto [factor, _] = factors()(values.discrete());
|
auto [factor, _] = factors()(values.discrete());
|
||||||
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
|
auto conditional = checkConditional(factor);
|
||||||
return conditional->evaluate(values.continuous());
|
return conditional->evaluate(values.continuous());
|
||||||
else
|
|
||||||
throw std::logic_error(
|
|
||||||
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
Loading…
Reference in New Issue