helper method to reduce code duplication

release/4.3a0
Varun Agrawal 2024-10-23 12:45:38 -04:00
parent 4c74ec113a
commit cbb0a30173
1 changed files with 25 additions and 23 deletions

View File

@ -33,6 +33,19 @@
#include <memory>
namespace gtsam {
/* *******************************************************************************/
GaussianConditional::shared_ptr checkConditional(
const GaussianFactor::shared_ptr &factor) {
if (auto conditional =
std::dynamic_pointer_cast<GaussianConditional>(factor)) {
return conditional;
} else {
throw std::logic_error(
"A HybridGaussianConditional unexpectedly contained a non-conditional");
}
}
/* *******************************************************************************/
/**
* @brief Helper struct for constructing HybridGaussianConditional objects
@ -92,10 +105,7 @@ struct HybridGaussianConditional::Helper {
explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
auto func = [this](const GaussianFactorValuePair &pair) {
if (!pair.first) return;
auto gc = std::dynamic_pointer_cast<GaussianConditional>(pair.first);
if (!gc)
throw std::runtime_error(
"HybridGaussianConditional called with non-conditional.");
auto gc = checkConditional(pair.first);
if (!nrFrontals) nrFrontals = gc->nrFrontals();
minNegLogConstant = std::min(minNegLogConstant, pair.second);
};
@ -179,14 +189,11 @@ size_t HybridGaussianConditional::nrComponents() const {
/* *******************************************************************************/
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
const DiscreteValues &discreteValues) const {
auto &[ptr, _] = factors()(discreteValues);
if (!ptr) return nullptr;
auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional)
return conditional;
else
throw std::logic_error(
"A HybridGaussianConditional unexpectedly contained a non-conditional");
auto &[factor, _] = factors()(discreteValues);
if (!factor) return nullptr;
auto conditional = checkConditional(factor);
return conditional;
}
/* *******************************************************************************/
@ -320,28 +327,23 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
};
FactorValuePairs prunedConditionals = factors().apply(pruner);
return std::make_shared<HybridGaussianConditional>(discreteKeys(), prunedConditionals);
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
prunedConditionals);
}
/* *******************************************************************************/
double HybridGaussianConditional::logProbability(
const HybridValues &values) const {
auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->logProbability(values.continuous());
else
throw std::logic_error(
"A HybridGaussianConditional unexpectedly contained a non-conditional");
auto conditional = checkConditional(factor);
return conditional->logProbability(values.continuous());
}
/* *******************************************************************************/
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
auto [factor, _] = factors()(values.discrete());
if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(factor))
return conditional->evaluate(values.continuous());
else
throw std::logic_error(
"A HybridGaussianConditional unexpectedly contained a non-conditional");
auto conditional = checkConditional(factor);
return conditional->evaluate(values.continuous());
}
} // namespace gtsam