gtsam/gtsam/hybrid/HybridNonlinearFactor.cpp

260 lines
10 KiB
C++

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridNonlinearFactor.h
* @brief A set of nonlinear factors indexed by a set of discrete keys.
* @author Varun Agrawal
* @date Sep 12, 2024
*/
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/linear/NoiseModel.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
#include <memory>
namespace gtsam {
/* *******************************************************************************/
struct HybridNonlinearFactor::ConstructorHelper {
KeyVector continuousKeys; // Continuous keys extracted from factors
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
FactorValuePairs factorTree;
void copyOrCheckContinuousKeys(const NoiseModelFactor::shared_ptr& factor) {
if (!factor) return;
if (continuousKeys.empty()) {
continuousKeys = factor->keys();
} else if (factor->keys() != continuousKeys) {
throw std::runtime_error(
"HybridNonlinearFactor: all factors should have the same keys!");
}
}
ConstructorHelper(const DiscreteKey& discreteKey,
const std::vector<NoiseModelFactor::shared_ptr>& factors)
: discreteKeys({discreteKey}) {
std::vector<NonlinearFactorValuePair> pairs;
// Extract continuous keys from the first non-null factor
for (const auto& factor : factors) {
pairs.emplace_back(factor, 0.0);
copyOrCheckContinuousKeys(factor);
}
factorTree = FactorValuePairs({discreteKey}, pairs);
}
ConstructorHelper(const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& pairs)
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto& pair : pairs) {
copyOrCheckContinuousKeys(pair.first);
}
factorTree = FactorValuePairs({discreteKey}, pairs);
}
ConstructorHelper(const DiscreteKeys& discreteKeys,
const FactorValuePairs& factorPairs)
: discreteKeys(discreteKeys), factorTree(factorPairs) {
// Extract continuous keys from the first non-null factor
factorPairs.visit([&](const NonlinearFactorValuePair& pair) {
copyOrCheckContinuousKeys(pair.first);
});
}
};
/* *******************************************************************************/
HybridNonlinearFactor::HybridNonlinearFactor(const ConstructorHelper& helper)
: Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.factorTree) {}
HybridNonlinearFactor::HybridNonlinearFactor(
const DiscreteKey& discreteKey,
const std::vector<NoiseModelFactor::shared_ptr>& factors)
: HybridNonlinearFactor(ConstructorHelper(discreteKey, factors)) {}
HybridNonlinearFactor::HybridNonlinearFactor(
const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& pairs)
: HybridNonlinearFactor(ConstructorHelper(discreteKey, pairs)) {}
HybridNonlinearFactor::HybridNonlinearFactor(const DiscreteKeys& discreteKeys,
const FactorValuePairs& factors)
: HybridNonlinearFactor(ConstructorHelper(discreteKeys, factors)) {}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridNonlinearFactor::errorTree(
const Values& continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const std::pair<sharedFactor, double>& f) {
auto [factor, val] = f;
return factor ? factor->error(continuousValues) + val
: std::numeric_limits<double>::infinity();
};
return {factors_, errorFunc};
}
/* *******************************************************************************/
double HybridNonlinearFactor::error(
const Values& continuousValues,
const DiscreteValues& discreteValues) const {
// Retrieve the factor corresponding to the assignment in discreteValues.
auto [factor, val] = factors_(discreteValues);
// Compute the error for the selected factor
const double factorError = factor->error(continuousValues);
return factorError + val;
}
/* *******************************************************************************/
double HybridNonlinearFactor::error(const HybridValues& values) const {
return error(values.nonlinear(), values.discrete());
}
/* *******************************************************************************/
size_t HybridNonlinearFactor::dim() const {
const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_);
auto [factor, val] = factors_(assignments.at(0));
return factor->dim();
}
/* *******************************************************************************/
void HybridNonlinearFactor::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
std::cout << (s.empty() ? "" : s + " ");
Base::print("", keyFormatter);
std::cout << "\nHybridNonlinearFactor\n";
auto valueFormatter = [](const std::pair<sharedFactor, double>& v) {
auto [factor, val] = v;
if (factor) {
return "Nonlinear factor on " + std::to_string(factor->size()) + " keys";
} else {
return std::string("nullptr");
}
};
factors_.print("", keyFormatter, valueFormatter);
}
/* *******************************************************************************/
bool HybridNonlinearFactor::equals(const HybridFactor& other,
double tol) const {
// We attempt a dynamic cast from HybridFactor to HybridNonlinearFactor. If
// it fails, return false.
if (!dynamic_cast<const HybridNonlinearFactor*>(&other)) return false;
// If the cast is successful, we'll properly construct a
// HybridNonlinearFactor object from `other`
const HybridNonlinearFactor& f(
static_cast<const HybridNonlinearFactor&>(other));
// Ensure that this HybridNonlinearFactor and `f` have the same `factors_`.
auto compare = [tol](const std::pair<sharedFactor, double>& a,
const std::pair<sharedFactor, double>& b) {
return a.first->equals(*b.first, tol) && (a.second == b.second);
};
if (!factors_.equals(f.factors_, compare)) return false;
// If everything above passes, and the keys_ and discreteKeys_
// member variables are identical, return true.
return (std::equal(keys_.begin(), keys_.end(), f.keys().begin()) &&
(discreteKeys_ == f.discreteKeys_));
}
/* *******************************************************************************/
GaussianFactor::shared_ptr HybridNonlinearFactor::linearize(
const Values& continuousValues,
const DiscreteValues& discreteValues) const {
auto factor = factors_(discreteValues).first;
return factor->linearize(continuousValues);
}
/* *******************************************************************************/
std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
const Values& continuousValues) const {
// functional to linearize each factor in the decision tree
auto linearizeDT =
[continuousValues](
const std::pair<sharedFactor, double>& f) -> GaussianFactorValuePair {
auto [factor, val] = f;
// Check if valid factor. If not, return null and infinite error.
if (!factor) {
return {nullptr, std::numeric_limits<double>::infinity()};
}
if (auto gaussian = std::dynamic_pointer_cast<noiseModel::Gaussian>(
factor->noiseModel())) {
return {factor->linearize(continuousValues),
val + gaussian->negLogConstant()};
} else {
throw std::runtime_error(
"HybridNonlinearFactor: linearize() only supports NoiseModelFactors "
"with Gaussian (or derived) noise models.");
}
};
DecisionTree<Key, std::pair<GaussianFactor::shared_ptr, double>>
linearized_factors(factors_, linearizeDT);
return std::make_shared<HybridGaussianFactor>(discreteKeys_,
linearized_factors);
}
/* *******************************************************************************/
HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune(
const DecisionTreeFactor& discreteProbs) const {
// Find keys in discreteProbs.keys() but not in this->keys():
std::set<Key> mine(this->keys().begin(), this->keys().end());
std::set<Key> theirs(discreteProbs.keys().begin(),
discreteProbs.keys().end());
std::vector<Key> diff;
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
std::back_inserter(diff));
// Find maximum probability value for every combination of our keys.
Ordering keys(diff);
auto max = discreteProbs.max(keys);
// Check the max value for every combination of our keys.
// If the max value is 0.0, we can prune the corresponding conditional.
auto pruner =
[&](const Assignment<Key>& choices,
const NonlinearFactorValuePair& pair) -> NonlinearFactorValuePair {
if (max->evaluate(choices) == 0.0)
return {nullptr, std::numeric_limits<double>::infinity()};
else
return pair;
};
FactorValuePairs prunedFactors = factors().apply(pruner);
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
}
/* ************************************************************************ */
std::shared_ptr<Factor> HybridNonlinearFactor::restrict(
const DiscreteValues& assignment) const {
auto restrictedFactors = factors_.restrict(assignment);
auto filtered = assignment.filter(discreteKeys_);
if (filtered.size() == discreteKeys_.size()) {
auto [nonlinearFactor, val] = factors_(filtered);
return nonlinearFactor;
} else {
auto remainingKeys = assignment.missingKeys(discreteKeys());
return std::make_shared<HybridNonlinearFactor>(remainingKeys,
factors_.restrict(filtered));
}
}
/* ************************************************************************ */
} // namespace gtsam