diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 6e5fe93c4..76a09bcf5 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -29,6 +29,8 @@ #include #include +#include +#include "gtsam/linear/GaussianConditional.h" namespace gtsam { /* *******************************************************************************/ @@ -38,14 +40,13 @@ namespace gtsam { * This struct contains the following fields: * - nrFrontals: Optional size_t for number of frontal variables * - pairs: FactorValuePairs for storing conditionals with their negLogConstant - * - conditionals: Conditionals for storing conditionals. TODO(frank): kill! * - minNegLogConstant: minimum negLogConstant, computed here, subtracted in * constructor */ struct HybridGaussianConditional::Helper { - std::optional nrFrontals; FactorValuePairs pairs; - double minNegLogConstant; + std::optional nrFrontals = {}; + double minNegLogConstant = std::numeric_limits::infinity(); using GC = GaussianConditional; using P = std::vector>; @@ -54,8 +55,6 @@ struct HybridGaussianConditional::Helper { template explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) { nrFrontals = 1; - minNegLogConstant = std::numeric_limits::infinity(); - std::vector fvs; std::vector gcs; fvs.reserve(p.size()); @@ -73,8 +72,7 @@ struct HybridGaussianConditional::Helper { } /// Construct from tree of GaussianConditionals. - explicit Helper(const Conditionals &conditionals) - : minNegLogConstant(std::numeric_limits::infinity()) { + explicit Helper(const Conditionals &conditionals) { auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair { if (!gc) return {nullptr, std::numeric_limits::infinity()}; if (!nrFrontals) nrFrontals = gc->nrFrontals(); @@ -89,6 +87,25 @@ struct HybridGaussianConditional::Helper { "Provided conditionals do not contain any frontal variables."); } } + + /// Construct from tree of factor/scalar pairs. + explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) { + auto func = [this](const GaussianFactorValuePair &pair) { + if (!pair.first) return; + auto gc = std::dynamic_pointer_cast(pair.first); + if (!gc) + throw std::runtime_error( + "HybridGaussianConditional called with non-conditional."); + if (!nrFrontals) nrFrontals = gc->nrFrontals(); + minNegLogConstant = std::min(minNegLogConstant, pair.second); + }; + pairs.visit(func); + if (!nrFrontals.has_value()) { + throw std::runtime_error( + "HybridGaussianConditional: need at least one frontal variable. " + "Provided conditionals do not contain any frontal variables."); + } + } }; /* *******************************************************************************/ @@ -138,6 +155,10 @@ HybridGaussianConditional::HybridGaussianConditional( const HybridGaussianConditional::Conditionals &conditionals) : HybridGaussianConditional(discreteParents, Helper(conditionals)) {} +HybridGaussianConditional::HybridGaussianConditional( + const DiscreteKeys &discreteParents, const FactorValuePairs &pairs) + : HybridGaussianConditional(discreteParents, Helper(pairs)) {} + /* *******************************************************************************/ const HybridGaussianConditional::Conditionals HybridGaussianConditional::conditionals() const { @@ -300,8 +321,7 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( FactorValuePairs prunedConditionals = factors().apply(pruner); return std::shared_ptr( - new HybridGaussianConditional(discreteKeys(), nrFrontals_, - prunedConditionals, negLogConstant_)); + new HybridGaussianConditional(discreteKeys(), prunedConditionals)); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 38b7b9795..c38d8733c 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -141,6 +141,19 @@ class GTSAM_EXPORT HybridGaussianConditional HybridGaussianConditional(const DiscreteKeys &discreteParents, const Conditionals &conditionals); + /** + * @brief Construct from multiple discrete keys M and a tree of + * factor/scalar pairs, where the scalar is assumed to be the + * the negative log constant for each assignment m, up to a constant. + * + * @note Will throw if factors are not actually conditionals. + * + * @param discreteParents the discrete parents. Will be placed last. + * @param conditionalPairs Decision tree of GaussianFactor/scalar pairs. + */ + HybridGaussianConditional(const DiscreteKeys &discreteParents, + const FactorValuePairs &pairs); + /// @} /// @name Testable /// @{ @@ -230,14 +243,6 @@ class GTSAM_EXPORT HybridGaussianConditional HybridGaussianConditional(const DiscreteKeys &discreteParents, const Helper &helper); - /// Private constructor used when constants have already been calculated. - HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals, - const FactorValuePairs &factors, - double negLogConstant) - : BaseFactor(discreteKeys, factors), - BaseConditional(nrFrontals), - negLogConstant_(negLogConstant) {} - /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const;