diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 4266ace15..bda44bb9d 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -286,6 +287,10 @@ namespace gtsam { return branches_; } + std::vector& branches() { + return branches_; + } + /** add a branch: TODO merge into constructor */ void push_back(NodePtr&& node) { // allSame_ is restricted to leaf nodes in a decision tree @@ -482,8 +487,8 @@ namespace gtsam { /****************************************************************************/ // DecisionTree /****************************************************************************/ - template - DecisionTree::DecisionTree() {} + template + DecisionTree::DecisionTree() : root_(nullptr) {} template DecisionTree::DecisionTree(const NodePtr& root) : @@ -554,6 +559,36 @@ namespace gtsam { root_ = compose(functions.begin(), functions.end(), label); } + /****************************************************************************/ + template + DecisionTree::DecisionTree(const Unary& op, + DecisionTree&& other) noexcept + : root_(std::move(other.root_)) { + // Apply the unary operation directly to each leaf in the tree + if (root_) { + // Define a helper function to traverse and apply the operation + struct ApplyUnary { + const Unary& op; + void operator()(typename DecisionTree::NodePtr& node) const { + if (auto leaf = std::dynamic_pointer_cast(node)) { + // Apply the unary operation to the leaf's constant value + leaf->constant_ = op(leaf->constant_); + } else if (auto choice = std::dynamic_pointer_cast(node)) { + // Recurse into the choice branches + for (NodePtr& branch : choice->branches()) { + (*this)(branch); + } + } + } + }; + + ApplyUnary applyUnary{op}; + applyUnary(root_); + } + // Reset the other tree's root to nullptr to avoid dangling references + other.root_ = nullptr; + } + /****************************************************************************/ template template @@ -694,7 +729,7 @@ namespace gtsam { typename DecisionTree::NodePtr DecisionTree::create( It begin, It end, ValueIt beginY, ValueIt endY) { auto node = build(begin, end, beginY, endY); - if (auto choice = std::dynamic_pointer_cast(node)) { + if (auto choice = std::dynamic_pointer_cast(node)) { return Choice::Unique(choice); } else { return node; @@ -710,7 +745,7 @@ namespace gtsam { // If leaf, apply unary conversion "op" and create a unique leaf. using LXLeaf = typename DecisionTree::Leaf; - if (auto leaf = std::dynamic_pointer_cast(f)) { + if (auto leaf = std::dynamic_pointer_cast(f)) { return NodePtr(new Leaf(Y_of_X(leaf->constant()))); } @@ -951,11 +986,16 @@ namespace gtsam { return root_->equals(*other.root_); } + /****************************************************************************/ template const Y& DecisionTree::operator()(const Assignment& x) const { + if (root_ == nullptr) + throw std::invalid_argument( + "DecisionTree::operator() called on empty tree"); return root_->operator ()(x); } + /****************************************************************************/ template DecisionTree DecisionTree::apply(const Unary& op) const { // It is unclear what should happen if tree is empty: @@ -966,6 +1006,7 @@ namespace gtsam { return DecisionTree(root_->apply(op)); } + /****************************************************************************/ /// Apply unary operator with assignment template DecisionTree DecisionTree::apply( @@ -1049,6 +1090,18 @@ namespace gtsam { return ss.str(); } -/******************************************************************************/ + /******************************************************************************/ + template + template + std::pair, DecisionTree> DecisionTree::split( + std::function(const Y&)> AB_of_Y) const { + using AB = std::pair; + const DecisionTree ab(*this, AB_of_Y); + const DecisionTree a(ab, [](const AB& p) { return p.first; }); + const DecisionTree b(ab, [](const AB& p) { return p.second; }); + return {a, b}; + } + + /******************************************************************************/ } // namespace gtsam diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 6d8d86530..486f798e9 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -85,7 +85,7 @@ namespace gtsam { /** ------------------------ Node base class --------------------------- */ struct Node { - using Ptr = std::shared_ptr; + using Ptr = std::shared_ptr; #ifdef DT_DEBUG_MEMORY static int nrNodes; @@ -156,10 +156,10 @@ namespace gtsam { template static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY); - /** Internal helper function to create from - * keys, cardinalities, and Y values. - * Calls `build` which builds thetree bottom-up, - * before we prune in a top-down fashion. + /** + * Internal helper function to create a tree from keys, cardinalities, and Y + * values. Calls `build` which builds the tree bottom-up, before we prune in + * a top-down fashion. */ template static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY); @@ -228,6 +228,15 @@ namespace gtsam { DecisionTree(const L& label, const DecisionTree& f0, const DecisionTree& f1); + /** + * @brief Move constructor for DecisionTree. Very efficient as does not + * allocate anything, just changes in-place. But `other` is consumed. + * + * @param op The unary operation to apply to the moved DecisionTree. + * @param other The DecisionTree to move from, will be empty afterwards. + */ + DecisionTree(const Unary& op, DecisionTree&& other) noexcept; + /** * @brief Convert from a different value type. * @@ -239,7 +248,7 @@ namespace gtsam { DecisionTree(const DecisionTree& other, Func Y_of_X); /** - * @brief Convert from a different value type X to value type Y, also transate + * @brief Convert from a different value type X to value type Y, also translate * labels via map from type M to L. * * @tparam M Previous label type. @@ -406,6 +415,18 @@ namespace gtsam { const ValueFormatter& valueFormatter, bool showZero = true) const; + /** + * @brief Convert into two trees with value types A and B. + * + * @tparam A First new value type. + * @tparam B Second new value type. + * @param AB_of_Y Functor to convert from type X to std::pair. + * @return A pair of DecisionTrees with value types A and B respectively. + */ + template + std::pair, DecisionTree> split( + std::function(const Y&)> AB_of_Y) const; + /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index c625e1ba6..526001b51 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -11,7 +11,7 @@ /* * @file testDecisionTree.cpp - * @brief Develop DecisionTree + * @brief DecisionTree unit tests * @author Frank Dellaert * @author Can Erdogan * @date Jan 30, 2012 @@ -108,6 +108,7 @@ struct DT : public DecisionTree { std::cout << s; Base::print("", keyFormatter, valueFormatter); } + /// Equality method customized to int node type bool equals(const Base& other, double tol = 1e-9) const { auto compare = [](const int& v, const int& w) { return v == w; }; @@ -271,6 +272,58 @@ TEST(DecisionTree, Example) { DOT(acnotb); } +/* ************************************************************************** */ +// Test that we can create two trees out of one, using a function that returns a pair. +TEST(DecisionTree, Split) { + // Create labels + string A("A"), B("B"); + + // Create a decision tree + DT original(A, DT(B, 1, 2), DT(B, 3, 4)); + + // Define a function that returns an int/bool pair + auto split_function = [](const int& value) -> std::pair { + return {value*3, value*3 % 2 == 0}; + }; + + // Split the original tree into two new trees + auto [la,lb] = original.split(split_function); + + // Check the first resulting tree + EXPECT_LONGS_EQUAL(3, la(Assignment{{A, 0}, {B, 0}})); + EXPECT_LONGS_EQUAL(6, la(Assignment{{A, 0}, {B, 1}})); + EXPECT_LONGS_EQUAL(9, la(Assignment{{A, 1}, {B, 0}})); + EXPECT_LONGS_EQUAL(12, la(Assignment{{A, 1}, {B, 1}})); + + // Check the second resulting tree + EXPECT(!lb(Assignment{{A, 0}, {B, 0}})); + EXPECT(lb(Assignment{{A, 0}, {B, 1}})); + EXPECT(!lb(Assignment{{A, 1}, {B, 0}})); + EXPECT(lb(Assignment{{A, 1}, {B, 1}})); +} + + +/* ************************************************************************** */ +// Test that we can create a tree by modifying an rvalue. +TEST(DecisionTree, Consume) { + // Create labels + string A("A"), B("B"); + + // Create a decision tree + DT original(A, DT(B, 1, 2), DT(B, 3, 4)); + + DT modified([](int i){return i*2;}, std::move(original)); + + // Check the first resulting tree + EXPECT_LONGS_EQUAL(2, modified(Assignment{{A, 0}, {B, 0}})); + EXPECT_LONGS_EQUAL(4, modified(Assignment{{A, 0}, {B, 1}})); + EXPECT_LONGS_EQUAL(6, modified(Assignment{{A, 1}, {B, 0}})); + EXPECT_LONGS_EQUAL(8, modified(Assignment{{A, 1}, {B, 1}})); + + // Check original was moved + EXPECT(original.root_ == nullptr); +} + /* ************************************************************************** */ // test Conversion of values bool bool_of_int(const int& y) { return y != 0; }; diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 58724163e..ac03bd3a3 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -25,12 +25,27 @@ #include #include #include +#include #include #include #include +#include namespace gtsam { + +/* *******************************************************************************/ +GaussianConditional::shared_ptr checkConditional( + const GaussianFactor::shared_ptr &factor) { + if (auto conditional = + std::dynamic_pointer_cast(factor)) { + return conditional; + } else { + throw std::logic_error( + "A HybridGaussianConditional unexpectedly contained a non-conditional"); + } +} + /* *******************************************************************************/ /** * @brief Helper struct for constructing HybridGaussianConditional objects @@ -38,15 +53,13 @@ namespace gtsam { * This struct contains the following fields: * - nrFrontals: Optional size_t for number of frontal variables * - pairs: FactorValuePairs for storing conditionals with their negLogConstant - * - conditionals: Conditionals for storing conditionals. TODO(frank): kill! * - minNegLogConstant: minimum negLogConstant, computed here, subtracted in * constructor */ struct HybridGaussianConditional::Helper { - std::optional nrFrontals; FactorValuePairs pairs; - Conditionals conditionals; - double minNegLogConstant; + std::optional nrFrontals = {}; + double minNegLogConstant = std::numeric_limits::infinity(); using GC = GaussianConditional; using P = std::vector>; @@ -55,8 +68,6 @@ struct HybridGaussianConditional::Helper { template explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) { nrFrontals = 1; - minNegLogConstant = std::numeric_limits::infinity(); - std::vector fvs; std::vector gcs; fvs.reserve(p.size()); @@ -70,14 +81,11 @@ struct HybridGaussianConditional::Helper { gcs.push_back(gaussianConditional); } - conditionals = Conditionals({mode}, gcs); pairs = FactorValuePairs({mode}, fvs); } /// Construct from tree of GaussianConditionals. - explicit Helper(const Conditionals &conditionals) - : conditionals(conditionals), - minNegLogConstant(std::numeric_limits::infinity()) { + explicit Helper(const Conditionals &conditionals) { auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair { if (!gc) return {nullptr, std::numeric_limits::infinity()}; if (!nrFrontals) nrFrontals = gc->nrFrontals(); @@ -92,21 +100,36 @@ struct HybridGaussianConditional::Helper { "Provided conditionals do not contain any frontal variables."); } } + + /// Construct from tree of factor/scalar pairs. + explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) { + auto func = [this](const GaussianFactorValuePair &pair) { + if (!pair.first) return; + auto gc = checkConditional(pair.first); + if (!nrFrontals) nrFrontals = gc->nrFrontals(); + minNegLogConstant = std::min(minNegLogConstant, pair.second); + }; + pairs.visit(func); + if (!nrFrontals.has_value()) { + throw std::runtime_error( + "HybridGaussianConditional: need at least one frontal variable. " + "Provided conditionals do not contain any frontal variables."); + } + } }; /* *******************************************************************************/ HybridGaussianConditional::HybridGaussianConditional( - const DiscreteKeys &discreteParents, const Helper &helper) + const DiscreteKeys &discreteParents, Helper &&helper) : BaseFactor(discreteParents, - FactorValuePairs(helper.pairs, - [&](const GaussianFactorValuePair & - pair) { // subtract minNegLogConstant - return GaussianFactorValuePair{ - pair.first, - pair.second - helper.minNegLogConstant}; - })), + FactorValuePairs( + [&](const GaussianFactorValuePair + &pair) { // subtract minNegLogConstant + return GaussianFactorValuePair{ + pair.first, pair.second - helper.minNegLogConstant}; + }, + std::move(helper.pairs))), BaseConditional(*helper.nrFrontals), - conditionals_(helper.conditionals), negLogConstant_(helper.minNegLogConstant) {} HybridGaussianConditional::HybridGaussianConditional( @@ -142,17 +165,23 @@ HybridGaussianConditional::HybridGaussianConditional( const HybridGaussianConditional::Conditionals &conditionals) : HybridGaussianConditional(discreteParents, Helper(conditionals)) {} +HybridGaussianConditional::HybridGaussianConditional( + const DiscreteKeys &discreteParents, const FactorValuePairs &pairs) + : HybridGaussianConditional(discreteParents, Helper(pairs)) {} + /* *******************************************************************************/ -const HybridGaussianConditional::Conditionals & +const HybridGaussianConditional::Conditionals HybridGaussianConditional::conditionals() const { - return conditionals_; + return Conditionals(factors(), [](auto &&pair) { + return std::dynamic_pointer_cast(pair.first); + }); } /* *******************************************************************************/ size_t HybridGaussianConditional::nrComponents() const { size_t total = 0; - conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) { - if (node) total += 1; + factors().visit([&total](auto &&node) { + if (node.first) total += 1; }); return total; } @@ -160,14 +189,11 @@ size_t HybridGaussianConditional::nrComponents() const { /* *******************************************************************************/ GaussianConditional::shared_ptr HybridGaussianConditional::choose( const DiscreteValues &discreteValues) const { - auto &ptr = conditionals_(discreteValues); - if (!ptr) return nullptr; - auto conditional = std::dynamic_pointer_cast(ptr); - if (conditional) - return conditional; - else - throw std::logic_error( - "A HybridGaussianConditional unexpectedly contained a non-conditional"); + auto &[factor, _] = factors()(discreteValues); + if (!factor) return nullptr; + + auto conditional = checkConditional(factor); + return conditional; } /* *******************************************************************************/ @@ -176,18 +202,16 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf, const This *e = dynamic_cast(&lf); if (e == nullptr) return false; - // This will return false if either conditionals_ is empty or e->conditionals_ - // is empty, but not if both are empty or both are not empty: - if (conditionals_.empty() ^ e->conditionals_.empty()) return false; - - // Check the base and the factors: - return BaseFactor::equals(*e, tol) && - conditionals_.equals(e->conditionals_, - [tol](const GaussianConditional::shared_ptr &f1, - const GaussianConditional::shared_ptr &f2) { - return (!f1 && !f2) || - (f1 && f2 && f1->equals(*f2, tol)); - }); + // Factors existence and scalar values are checked in BaseFactor::equals. + // Here we check additionally that the factors *are* conditionals + // and are equal. + auto compareFunc = [tol](const GaussianFactorValuePair &pair1, + const GaussianFactorValuePair &pair2) { + auto c1 = std::dynamic_pointer_cast(pair1.first), + c2 = std::dynamic_pointer_cast(pair2.first); + return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol)); + }; + return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc); } /* *******************************************************************************/ @@ -202,11 +226,12 @@ void HybridGaussianConditional::print(const std::string &s, std::cout << std::endl << " logNormalizationConstant: " << -negLogConstant() << std::endl << std::endl; - conditionals_.print( + factors().print( "", [&](Key k) { return formatter(k); }, - [&](const GaussianConditional::shared_ptr &gf) -> std::string { + [&](const GaussianFactorValuePair &pair) -> std::string { RedirectCout rd; - if (gf && !gf->empty()) { + if (auto gf = + std::dynamic_pointer_cast(pair.first)) { gf->print("", formatter); return rd.str(); } else { @@ -254,12 +279,16 @@ std::shared_ptr HybridGaussianConditional::likelihood( const DiscreteKeys discreteParentKeys = discreteKeys(); const KeyVector continuousParentKeys = continuousParents(); const HybridGaussianFactor::FactorValuePairs likelihoods( - conditionals_, - [&](const GaussianConditional::shared_ptr &conditional) - -> GaussianFactorValuePair { - const auto likelihood_m = conditional->likelihood(given); - const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_; - return {likelihood_m, Cgm_Kgcm}; + factors(), + [&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { + if (auto conditional = + std::dynamic_pointer_cast(pair.first)) { + const auto likelihood_m = conditional->likelihood(given); + // pair.second == conditional->negLogConstant() - negLogConstant_ + return {likelihood_m, pair.second}; + } else { + return {nullptr, std::numeric_limits::infinity()}; + } }); return std::make_shared(discreteParentKeys, likelihoods); @@ -288,27 +317,32 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( // Check the max value for every combination of our keys. // If the max value is 0.0, we can prune the corresponding conditional. - auto pruner = [&](const Assignment &choices, - const GaussianConditional::shared_ptr &conditional) - -> GaussianConditional::shared_ptr { - return (max->evaluate(choices) == 0.0) ? nullptr : conditional; + auto pruner = + [&](const Assignment &choices, + const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { + if (max->evaluate(choices) == 0.0) + return {nullptr, std::numeric_limits::infinity()}; + else + return pair; }; - auto pruned_conditionals = conditionals_.apply(pruner); + FactorValuePairs prunedConditionals = factors().apply(pruner); return std::make_shared(discreteKeys(), - pruned_conditionals); + prunedConditionals); } /* *******************************************************************************/ double HybridGaussianConditional::logProbability( const HybridValues &values) const { - auto conditional = conditionals_(values.discrete()); + auto [factor, _] = factors()(values.discrete()); + auto conditional = checkConditional(factor); return conditional->logProbability(values.continuous()); } /* *******************************************************************************/ double HybridGaussianConditional::evaluate(const HybridValues &values) const { - auto conditional = conditionals_(values.discrete()); + auto [factor, _] = factors()(values.discrete()); + auto conditional = checkConditional(factor); return conditional->evaluate(values.continuous()); } diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 4cc3d3196..c485fafce 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -64,8 +64,6 @@ class GTSAM_EXPORT HybridGaussianConditional using Conditionals = DecisionTree; private: - Conditionals conditionals_; ///< a decision tree of Gaussian conditionals. - ///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))). ///< Take advantage of the neg-log space so everything is a minimization double negLogConstant_; @@ -143,6 +141,19 @@ class GTSAM_EXPORT HybridGaussianConditional HybridGaussianConditional(const DiscreteKeys &discreteParents, const Conditionals &conditionals); + /** + * @brief Construct from multiple discrete keys M and a tree of + * factor/scalar pairs, where the scalar is assumed to be the + * the negative log constant for each assignment m, up to a constant. + * + * @note Will throw if factors are not actually conditionals. + * + * @param discreteParents the discrete parents. Will be placed last. + * @param conditionalPairs Decision tree of GaussianFactor/scalar pairs. + */ + HybridGaussianConditional(const DiscreteKeys &discreteParents, + const FactorValuePairs &pairs); + /// @} /// @name Testable /// @{ @@ -192,8 +203,9 @@ class GTSAM_EXPORT HybridGaussianConditional std::shared_ptr likelihood( const VectorValues &given) const; - /// Getter for the underlying Conditionals DecisionTree - const Conditionals &conditionals() const; + /// Get Conditionals DecisionTree (dynamic cast from factors) + /// @note Slow: avoid using in favor of factors(), which uses existing tree. + const Conditionals conditionals() const; /** * @brief Compute the logProbability of this hybrid Gaussian conditional. @@ -229,7 +241,7 @@ class GTSAM_EXPORT HybridGaussianConditional /// Private constructor that uses helper struct above. HybridGaussianConditional(const DiscreteKeys &discreteParents, - const Helper &helper); + Helper &&helper); /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const; @@ -241,7 +253,6 @@ class GTSAM_EXPORT HybridGaussianConditional void serialize(Archive &ar, const unsigned int /*version*/) { ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); - ar &BOOST_SERIALIZATION_NVP(conditionals_); } #endif }; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8e6123f10..ceabe0871 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -48,8 +49,6 @@ #include #include -#include "gtsam/discrete/DecisionTreeFactor.h" - namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: @@ -57,10 +56,20 @@ template class EliminateableFactorGraph; using std::dynamic_pointer_cast; using OrphanWrapper = BayesTreeOrphanWrapper; -using Result = - std::pair, GaussianFactor::shared_ptr>; -using ResultValuePair = std::pair; -using ResultTree = DecisionTree; + +/// Result from elimination. +struct Result { + GaussianConditional::shared_ptr conditional; + double negLogK; + GaussianFactor::shared_ptr factor; + double scalar; + + bool operator==(const Result &other) const { + return conditional == other.conditional && negLogK == other.negLogK && + factor == other.factor && scalar == other.scalar; + } +}; +using ResultTree = DecisionTree; static const VectorValues kEmpty; @@ -294,17 +303,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors, static std::shared_ptr createDiscreteFactor( const ResultTree &eliminationResults, const DiscreteKeys &discreteSeparator) { - auto calculateError = [&](const auto &pair) -> double { - const auto &[conditional, factor] = pair.first; - const double scalar = pair.second; - if (conditional && factor) { + auto calculateError = [&](const Result &result) -> double { + if (result.conditional && result.factor) { // `error` has the following contributions: // - the scalar is the sum of all mode-dependent constants // - factor->error(kempty) is the error remaining after elimination // - negLogK is what is given to the conditional to normalize - const double negLogK = conditional->negLogConstant(); - return scalar + factor->error(kEmpty) - negLogK; - } else if (!conditional && !factor) { + return result.scalar + result.factor->error(kEmpty) - result.negLogK; + } else if (!result.conditional && !result.factor) { // If the factor has been pruned, return infinite error return std::numeric_limits::infinity(); } else { @@ -323,13 +329,10 @@ static std::shared_ptr createHybridGaussianFactor( const ResultTree &eliminationResults, const DiscreteKeys &discreteSeparator) { // Correct for the normalization constant used up by the conditional - auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair { - const auto &[conditional, factor] = pair.first; - const double scalar = pair.second; - if (conditional && factor) { - const double negLogK = conditional->negLogConstant(); - return {factor, scalar - negLogK}; - } else if (!conditional && !factor) { + auto correct = [&](const Result &result) -> GaussianFactorValuePair { + if (result.conditional && result.factor) { + return {result.factor, result.scalar - result.negLogK}; + } else if (!result.conditional && !result.factor) { return {nullptr, std::numeric_limits::infinity()}; } else { throw std::runtime_error("createHybridGaussianFactors has mixed NULLs"); @@ -363,34 +366,34 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { // any difference in noise models used. HybridGaussianProductFactor productFactor = collectProductFactor(); - // Convert factor graphs with a nullptr to an empty factor graph. - // This is done after assembly since it is non-trivial to keep track of which - // FG has a nullptr as we're looping over the factors. - auto prunedProductFactor = productFactor.removeEmpty(); + // Check if a factor is null + auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }; // This is the elimination method on the leaf nodes bool someContinuousLeft = false; - auto eliminate = [&](const std::pair &pair) - -> std::pair { + auto eliminate = + [&](const std::pair &pair) -> Result { const auto &[graph, scalar] = pair; - if (graph.empty()) { - return {{nullptr, nullptr}, 0.0}; + // If any product contains a pruned factor, prune it here. Done here as it's + // non non-trivial to do within collectProductFactor. + if (graph.empty() || std::any_of(graph.begin(), graph.end(), isNull)) { + return {nullptr, 0.0, nullptr, 0.0}; } // Expensive elimination of product factor. - auto result = + auto [conditional, factor] = EliminatePreferCholesky(graph, keys); /// <<<<<< MOST COMPUTE IS HERE // Record whether there any continuous variables left - someContinuousLeft |= !result.second->empty(); + someContinuousLeft |= !factor->empty(); // We pass on the scalar unmodified. - return {result, scalar}; + return {conditional, conditional->negLogConstant(), factor, scalar}; }; // Perform elimination! - ResultTree eliminationResults(prunedProductFactor, eliminate); + const ResultTree eliminationResults(productFactor, eliminate); // If there are no more continuous parents we create a DiscreteFactor with the // error for each discrete choice. Otherwise, create a HybridGaussianFactor @@ -400,12 +403,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { ? createHybridGaussianFactor(eliminationResults, discreteSeparator) : createDiscreteFactor(eliminationResults, discreteSeparator); - // Create the HybridGaussianConditional from the conditionals - HybridGaussianConditional::Conditionals conditionals( - eliminationResults, - [](const ResultValuePair &pair) { return pair.first.first; }); - auto hybridGaussian = std::make_shared( - discreteSeparator, conditionals); + // Create the HybridGaussianConditional without re-calculating constants: + HybridGaussianConditional::FactorValuePairs pairs( + eliminationResults, [](const Result &result) -> GaussianFactorValuePair { + return {result.conditional, result.negLogK}; + }); + auto hybridGaussian = + std::make_shared(discreteSeparator, pairs); return {std::make_shared(hybridGaussian), newFactor}; }