Fix initialization issue with helper class

release/4.3a0
Frank Dellaert 2024-09-26 11:43:05 -07:00
parent b45ba003ca
commit ebebf7ddd5
2 changed files with 36 additions and 30 deletions

View File

@ -28,57 +28,43 @@
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam { namespace gtsam {
HybridGaussianConditional::HybridGaussianConditional( /* *******************************************************************************/
const DiscreteKeys &discreteParents, HybridGaussianConditional::ConstructorHelper::ConstructorHelper(
const HybridGaussianConditional::Conditionals &conditionals) const HybridGaussianConditional::Conditionals &conditionals) {
: BaseConditional(0) { // Initialize with zero; we'll set it properly later negLogConstant = std::numeric_limits<double>::infinity();
// Check if conditionals are empty
if (conditionals.empty()) {
throw std::invalid_argument("Conditionals cannot be empty");
}
KeyVector frontals, parents;
negLogConstant_ = std::numeric_limits<double>::infinity();
auto func = auto func =
[&](const GaussianConditional::shared_ptr &c) -> GaussianFactorValuePair { [&](const GaussianConditional::shared_ptr &c) -> GaussianFactorValuePair {
double value = 0.0; double value = 0.0;
// Check if conditional is pruned
if (c) { if (c) {
KeyVector cf(c->frontals().begin(), c->frontals().end()); KeyVector cf(c->frontals().begin(), c->frontals().end());
KeyVector cp(c->parents().begin(), c->parents().end()); KeyVector cp(c->parents().begin(), c->parents().end());
if (frontals.empty()) { if (frontals.empty()) {
// Get frontal/parent keys from first conditional.
frontals = cf; frontals = cf;
parents = cp; parents = cp;
} else if (cf != frontals || cp != parents) { } else if (cf != frontals || cp != parents) {
throw std::invalid_argument( throw std::invalid_argument(
"All conditionals must have the same frontals and parents"); "All conditionals must have the same frontals and parents");
} }
// Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
value = c->negLogConstant(); value = c->negLogConstant();
this->negLogConstant_ = std::min(this->negLogConstant_, value); negLogConstant = std::min(negLogConstant, value);
} }
return {std::dynamic_pointer_cast<GaussianFactor>(c), value}; return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
}; };
HybridGaussianFactor::FactorValuePairs pairs(conditionals, func); pairs = HybridGaussianFactor::FactorValuePairs(conditionals, func);
// Adjust frontals size // Build continuousKeys
BaseConditional::nrFrontals_ = frontals.size(); continuousKeys = frontals;
continuousKeys.insert(continuousKeys.end(), parents.begin(), parents.end());
// Initialize HybridFactor
HybridFactor::category_ = HybridFactor::Category::Hybrid;
HybridFactor::discreteKeys_ = discreteParents;
HybridFactor::keys_ = frontals;
keys_.insert(keys_.end(), parents.begin(), parents.end());
// Initialize BaseFactor
BaseFactor::factors_ = BaseFactor::augment(pairs); // TODO(frank): expensive
// Assign local conditionals. TODO(frank): a duplicate of factors_ !!!
conditionals_ = conditionals;
} }
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals)
: HybridGaussianConditional(discreteParents, conditionals,
ConstructorHelper(conditionals)) {}
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &discreteParent, const DiscreteKey &discreteParent,

View File

@ -68,6 +68,26 @@ class GTSAM_EXPORT HybridGaussianConditional
///< Take advantage of the neg-log space so everything is a minimization ///< Take advantage of the neg-log space so everything is a minimization
double negLogConstant_; double negLogConstant_;
/// Helper struct for private constructor.
struct ConstructorHelper {
KeyVector frontals;
KeyVector parents;
KeyVector continuousKeys;
HybridGaussianFactor::FactorValuePairs pairs;
double negLogConstant;
ConstructorHelper(const Conditionals &conditionals);
};
/// Private constructor
HybridGaussianConditional(
const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals,
const ConstructorHelper &helper)
: BaseFactor(helper.continuousKeys, discreteParents, helper.pairs),
BaseConditional(helper.frontals.size()),
conditionals_(conditionals),
negLogConstant_(helper.negLogConstant) {}
/** /**
* @brief Convert a HybridGaussianConditional of conditionals into * @brief Convert a HybridGaussianConditional of conditionals into
* a DecisionTree of Gaussian factor graphs. * a DecisionTree of Gaussian factor graphs.