Major improvement -> no need to copy keys anymore

release/4.3a0
Frank Dellaert 2024-09-27 06:49:36 -07:00
parent bc555aef0a
commit 2b26b3c971
1 changed files with 17 additions and 12 deletions

View File

@ -27,30 +27,35 @@
#include <gtsam/linear/GaussianBayesNet.h> #include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
#include <cstddef>
namespace gtsam { namespace gtsam {
/* *******************************************************************************/ /* *******************************************************************************/
struct HybridGaussianConditional::ConstructorHelper { struct HybridGaussianConditional::ConstructorHelper {
KeyVector frontals, parents; std::optional<size_t> nrFrontals;
HybridGaussianFactor::FactorValuePairs pairs; HybridGaussianFactor::FactorValuePairs pairs;
double negLogConstant; double minNegLogConstant;
/// Compute all variables needed for the private constructor below.
ConstructorHelper(const Conditionals &conditionals) {
negLogConstant = std::numeric_limits<double>::infinity();
auto func = [&](const GaussianConditional::shared_ptr &c) /// Compute all variables needed for the private constructor below.
ConstructorHelper(const Conditionals &conditionals)
: minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GaussianConditional::shared_ptr &c)
-> GaussianFactorValuePair { -> GaussianFactorValuePair {
double value = 0.0; double value = 0.0;
if (c) { if (c) {
if (frontals.empty()) { if (!nrFrontals.has_value()) {
frontals = KeyVector(c->frontals().begin(), c->frontals().end()); nrFrontals = c->nrFrontals();
parents = KeyVector(c->parents().begin(), c->parents().end());
} }
value = c->negLogConstant(); value = c->negLogConstant();
negLogConstant = std::min(negLogConstant, value); minNegLogConstant = std::min(minNegLogConstant, value);
} }
return {std::dynamic_pointer_cast<GaussianFactor>(c), value}; return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
}; };
pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func); pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func);
if (!nrFrontals.has_value()) {
throw std::runtime_error(
"HybridGaussianConditional: need at least one frontal variable.");
}
} }
}; };
@ -60,9 +65,9 @@ HybridGaussianConditional::HybridGaussianConditional(
const HybridGaussianConditional::Conditionals &conditionals, const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper) const ConstructorHelper &helper)
: BaseFactor(discreteParents, helper.pairs), : BaseFactor(discreteParents, helper.pairs),
BaseConditional(helper.frontals.size()), BaseConditional(*helper.nrFrontals),
conditionals_(conditionals), conditionals_(conditionals),
negLogConstant_(helper.negLogConstant) {} negLogConstant_(helper.minNegLogConstant) {}
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,