From 1f11b5472f957de00e614242fbc3e00c578125c3 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 26 Sep 2024 10:41:04 -0700 Subject: [PATCH] Initial attempt at removing extra arguments --- gtsam/hybrid/HybridGaussianConditional.cpp | 85 ++++++++++++---------- gtsam/hybrid/HybridGaussianConditional.h | 9 +-- gtsam/hybrid/HybridGaussianFactor.h | 2 +- 3 files changed, 49 insertions(+), 47 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index ef38237f2..00ead068a 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -28,56 +28,65 @@ #include namespace gtsam { -HybridGaussianFactor::FactorValuePairs GetFactorValuePairs( - const HybridGaussianConditional::Conditionals &conditionals) { - auto func = [](const GaussianConditional::shared_ptr &conditional) - -> GaussianFactorValuePair { - double value = 0.0; - // Check if conditional is pruned - if (conditional) { - // Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|)) - value = conditional->negLogConstant(); - } - return {std::dynamic_pointer_cast(conditional), value}; - }; - return HybridGaussianFactor::FactorValuePairs(conditionals, func); -} - HybridGaussianConditional::HybridGaussianConditional( - const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents, const HybridGaussianConditional::Conditionals &conditionals) - : BaseFactor(CollectKeys(continuousFrontals, continuousParents), - discreteParents, GetFactorValuePairs(conditionals)), - BaseConditional(continuousFrontals.size()), - conditionals_(conditionals) { - // Calculate negLogConstant_ as the minimum of the negative-log normalizers of - // the conditionals, by visiting the decision tree: + : BaseConditional(0) { // Initialize with zero; we'll set it properly later + + // Check if conditionals are empty + if (conditionals.empty()) { + throw std::invalid_argument("Conditionals cannot be empty"); + } + + KeyVector frontals, parents; negLogConstant_ = std::numeric_limits::infinity(); - conditionals_.visit( - [this](const GaussianConditional::shared_ptr &conditional) { - if (conditional) { - this->negLogConstant_ = - std::min(this->negLogConstant_, conditional->negLogConstant()); - } - }); + auto func = + [&](const GaussianConditional::shared_ptr &c) -> GaussianFactorValuePair { + double value = 0.0; + // Check if conditional is pruned + if (c) { + KeyVector cf(c->frontals().begin(), c->frontals().end()); + KeyVector cp(c->parents().begin(), c->parents().end()); + if (frontals.empty()) { + // Get frontal/parent keys from first conditional. + frontals = cf; + parents = cp; + } else if (cf != frontals || cp != parents) { + throw std::invalid_argument( + "All conditionals must have the same frontals and parents"); + } + // Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|)) + value = c->negLogConstant(); + this->negLogConstant_ = std::min(this->negLogConstant_, value); + } + return {std::dynamic_pointer_cast(c), value}; + }; + BaseFactor::factors_ = HybridGaussianFactor::Factors(conditionals, func); + + // Initialize base classes + KeyVector continuousKeys = frontals; + continuousKeys.insert(continuousKeys.end(), parents.begin(), parents.end()); + BaseFactor::keys_ = continuousKeys; + BaseFactor::discreteKeys_ = discreteParents; + BaseConditional::nrFrontals_ = frontals.size(); + + // Assign conditionals + conditionals_ = conditionals; // TODO(frank): a duplicate of factors_ !!! } +/* *******************************************************************************/ +HybridGaussianConditional::HybridGaussianConditional( + const DiscreteKey &discreteParent, + const std::vector &conditionals) + : HybridGaussianConditional(DiscreteKeys{discreteParent}, + Conditionals({discreteParent}, conditionals)) {} + /* *******************************************************************************/ const HybridGaussianConditional::Conditionals & HybridGaussianConditional::conditionals() const { return conditionals_; } -/* *******************************************************************************/ -HybridGaussianConditional::HybridGaussianConditional( - const KeyVector &continuousFrontals, const KeyVector &continuousParents, - const DiscreteKey &discreteParent, - const std::vector &conditionals) - : HybridGaussianConditional(continuousFrontals, continuousParents, - DiscreteKeys{discreteParent}, - Conditionals({discreteParent}, conditionals)) {} - /* *******************************************************************************/ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree() const { diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 9c70cf6cb..6198f65fb 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -95,17 +95,13 @@ class GTSAM_EXPORT HybridGaussianConditional /** * @brief Construct a new HybridGaussianConditional object. * - * @param continuousFrontals the continuous frontals. - * @param continuousParents the continuous parents. * @param discreteParents the discrete parents. Will be placed last. * @param conditionals a decision tree of GaussianConditionals. The number of * conditionals should be C^(number of discrete parents), where C is the * cardinality of the DiscreteKeys in discreteParents, since the * discreteParents will be used as the labels in the decision tree. */ - HybridGaussianConditional(const KeyVector &continuousFrontals, - const KeyVector &continuousParents, - const DiscreteKeys &discreteParents, + HybridGaussianConditional(const DiscreteKeys &discreteParents, const Conditionals &conditionals); /** @@ -113,14 +109,11 @@ class GTSAM_EXPORT HybridGaussianConditional * a vector of Gaussian conditionals. * The DecisionTree-based constructor is preferred over this one. * - * @param continuousFrontals The continuous frontal variables - * @param continuousParents The continuous parent variables * @param discreteParent Single discrete parent variable * @param conditionals Vector of conditionals with the same size as the * cardinality of the discrete parent. */ HybridGaussianConditional( - const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKey &discreteParent, const std::vector &conditionals); diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 817e54e56..1aa6a0263 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -69,7 +69,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// typedef for Decision Tree of Gaussian factors. using Factors = DecisionTree; - private: + protected: /// Decision tree of Gaussian factors indexed by discrete keys. Factors factors_;