/* ---------------------------------------------------------------------------- * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) * See LICENSE for the license information * -------------------------------------------------------------------------- */ /** * @file HybridNonlinearFactor.h * @brief A set of nonlinear factors indexed by a set of discrete keys. * @author Varun Agrawal * @date Sep 12, 2024 */ #include #include #include #include #include namespace gtsam { /* *******************************************************************************/ struct HybridNonlinearFactor::ConstructorHelper { KeyVector continuousKeys; // Continuous keys extracted from factors DiscreteKeys discreteKeys; // Discrete keys provided to the constructors FactorValuePairs factorTree; void copyOrCheckContinuousKeys(const NoiseModelFactor::shared_ptr& factor) { if (!factor) return; if (continuousKeys.empty()) { continuousKeys = factor->keys(); } else if (factor->keys() != continuousKeys) { throw std::runtime_error( "HybridNonlinearFactor: all factors should have the same keys!"); } } ConstructorHelper(const DiscreteKey& discreteKey, const std::vector& factors) : discreteKeys({discreteKey}) { std::vector pairs; // Extract continuous keys from the first non-null factor for (const auto& factor : factors) { pairs.emplace_back(factor, 0.0); copyOrCheckContinuousKeys(factor); } factorTree = FactorValuePairs({discreteKey}, pairs); } ConstructorHelper(const DiscreteKey& discreteKey, const std::vector& pairs) : discreteKeys({discreteKey}) { // Extract continuous keys from the first non-null factor for (const auto& pair : pairs) { copyOrCheckContinuousKeys(pair.first); } factorTree = FactorValuePairs({discreteKey}, pairs); } ConstructorHelper(const DiscreteKeys& discreteKeys, const FactorValuePairs& factorPairs) : discreteKeys(discreteKeys), factorTree(factorPairs) { // Extract continuous keys from the first non-null factor factorPairs.visit([&](const NonlinearFactorValuePair& pair) { copyOrCheckContinuousKeys(pair.first); }); } }; /* *******************************************************************************/ HybridNonlinearFactor::HybridNonlinearFactor(const ConstructorHelper& helper) : Base(helper.continuousKeys, helper.discreteKeys), factors_(helper.factorTree) {} HybridNonlinearFactor::HybridNonlinearFactor( const DiscreteKey& discreteKey, const std::vector& factors) : HybridNonlinearFactor(ConstructorHelper(discreteKey, factors)) {} HybridNonlinearFactor::HybridNonlinearFactor( const DiscreteKey& discreteKey, const std::vector& pairs) : HybridNonlinearFactor(ConstructorHelper(discreteKey, pairs)) {} HybridNonlinearFactor::HybridNonlinearFactor(const DiscreteKeys& discreteKeys, const FactorValuePairs& factors) : HybridNonlinearFactor(ConstructorHelper(discreteKeys, factors)) {} /* *******************************************************************************/ AlgebraicDecisionTree HybridNonlinearFactor::errorTree( const Values& continuousValues) const { // functor to convert from sharedFactor to double error value. auto errorFunc = [continuousValues](const std::pair& f) { auto [factor, val] = f; return factor ? factor->error(continuousValues) + val : std::numeric_limits::infinity(); }; return {factors_, errorFunc}; } /* *******************************************************************************/ double HybridNonlinearFactor::error( const Values& continuousValues, const DiscreteValues& discreteValues) const { // Retrieve the factor corresponding to the assignment in discreteValues. auto [factor, val] = factors_(discreteValues); // Compute the error for the selected factor const double factorError = factor->error(continuousValues); return factorError + val; } /* *******************************************************************************/ double HybridNonlinearFactor::error(const HybridValues& values) const { return error(values.nonlinear(), values.discrete()); } /* *******************************************************************************/ size_t HybridNonlinearFactor::dim() const { const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_); auto [factor, val] = factors_(assignments.at(0)); return factor->dim(); } /* *******************************************************************************/ void HybridNonlinearFactor::print(const std::string& s, const KeyFormatter& keyFormatter) const { std::cout << (s.empty() ? "" : s + " "); Base::print("", keyFormatter); std::cout << "\nHybridNonlinearFactor\n"; auto valueFormatter = [&keyFormatter](const std::pair& v) { auto [factor, val] = v; if (factor) { RedirectCout rd; factor->print("", keyFormatter); return rd.str(); } else { return std::string("nullptr"); } }; factors_.print("", keyFormatter, valueFormatter); } /* *******************************************************************************/ bool HybridNonlinearFactor::equals(const HybridFactor& other, double tol) const { // We attempt a dynamic cast from HybridFactor to HybridNonlinearFactor. If // it fails, return false. if (!dynamic_cast(&other)) return false; // If the cast is successful, we'll properly construct a // HybridNonlinearFactor object from `other` const HybridNonlinearFactor& f( static_cast(other)); // Ensure that this HybridNonlinearFactor and `f` have the same `factors_`. auto compare = [tol](const std::pair& a, const std::pair& b) { return a.first->equals(*b.first, tol) && (a.second == b.second); }; if (!factors_.equals(f.factors_, compare)) return false; // If everything above passes, and the keys_ and discreteKeys_ // member variables are identical, return true. return (std::equal(keys_.begin(), keys_.end(), f.keys().begin()) && (discreteKeys_ == f.discreteKeys_)); } /* *******************************************************************************/ GaussianFactor::shared_ptr HybridNonlinearFactor::linearize( const Values& continuousValues, const DiscreteValues& discreteValues) const { auto factor = factors_(discreteValues).first; return factor->linearize(continuousValues); } /* *******************************************************************************/ std::shared_ptr HybridNonlinearFactor::linearize( const Values& continuousValues) const { // functional to linearize each factor in the decision tree auto linearizeDT = [continuousValues]( const std::pair& f) -> GaussianFactorValuePair { auto [factor, val] = f; // Check if valid factor. If not, return null and infinite error. if (!factor) { return {nullptr, std::numeric_limits::infinity()}; } if (auto gaussian = std::dynamic_pointer_cast( factor->noiseModel())) { return {factor->linearize(continuousValues), val + gaussian->negLogConstant()}; } else { throw std::runtime_error( "HybridNonlinearFactor: linearize() only supports NoiseModelFactors " "with Gaussian (or derived) noise models."); } }; DecisionTree> linearized_factors(factors_, linearizeDT); return std::make_shared(discreteKeys_, linearized_factors); } /* *******************************************************************************/ HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune( const DecisionTreeFactor& discreteProbs) const { // Find keys in discreteProbs.keys() but not in this->keys(): std::set mine(this->keys().begin(), this->keys().end()); std::set theirs(discreteProbs.keys().begin(), discreteProbs.keys().end()); std::vector diff; std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), std::back_inserter(diff)); // Find maximum probability value for every combination of our keys. Ordering keys(diff); auto max = discreteProbs.max(keys); // Check the max value for every combination of our keys. // If the max value is 0.0, we can prune the corresponding conditional. auto pruner = [&](const Assignment& choices, const NonlinearFactorValuePair& pair) -> NonlinearFactorValuePair { if (max->evaluate(choices) == 0.0) return {nullptr, std::numeric_limits::infinity()}; else return pair; }; FactorValuePairs prunedFactors = factors().apply(pruner); return std::make_shared(discreteKeys(), prunedFactors); } /* ************************************************************************ */ std::shared_ptr HybridNonlinearFactor::restrict( const DiscreteValues& assignment) const { auto restrictedFactors = factors_.restrict(assignment); auto filtered = assignment.filter(discreteKeys_); if (filtered.size() == discreteKeys_.size()) { auto [nonlinearFactor, val] = factors_(filtered); return nonlinearFactor; } else { auto remainingKeys = assignment.missingKeys(discreteKeys()); return std::make_shared(remainingKeys, factors_.restrict(filtered)); } } /* ************************************************************************ */ } // namespace gtsam