split HybridNonlinearFactor into .h and .cpp

release/4.3a0
Varun Agrawal 2024-09-19 15:55:13 -04:00
parent 244661afb1
commit af06b33825
2 changed files with 151 additions and 87 deletions

View File

@ -0,0 +1,140 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridNonlinearFactor.h
* @brief A set of nonlinear factors indexed by a set of discrete keys.
* @author Varun Agrawal
* @date Sep 12, 2024
*/
// #include <gtsam/base/utilities.h>
// #include <gtsam/discrete/DecisionTree-inl.h>
// #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/HybridNonlinearFactor.h>
// #include <gtsam/hybrid/HybridValues.h>
// #include <gtsam/linear/GaussianFactor.h>
// #include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
/* *******************************************************************************/
HybridNonlinearFactor::HybridNonlinearFactor(const KeyVector& keys,
const DiscreteKeys& discreteKeys,
const Factors& factors)
: Base(keys, discreteKeys), factors_(factors) {}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridNonlinearFactor::errorTree(
const Values& continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const std::pair<sharedFactor, double>& f) {
auto [factor, val] = f;
return factor->error(continuousValues) + val;
};
DecisionTree<Key, double> result(factors_, errorFunc);
return result;
}
/* *******************************************************************************/
double HybridNonlinearFactor::error(
const Values& continuousValues,
const DiscreteValues& discreteValues) const {
// Retrieve the factor corresponding to the assignment in discreteValues.
auto [factor, val] = factors_(discreteValues);
// Compute the error for the selected factor
const double factorError = factor->error(continuousValues);
return factorError + val;
}
/* *******************************************************************************/
double HybridNonlinearFactor::error(const HybridValues& values) const {
return error(values.nonlinear(), values.discrete());
}
/* *******************************************************************************/
size_t HybridNonlinearFactor::dim() const {
const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_);
auto [factor, val] = factors_(assignments.at(0));
return factor->dim();
}
/* *******************************************************************************/
void HybridNonlinearFactor::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
std::cout << (s.empty() ? "" : s + " ");
Base::print("", keyFormatter);
std::cout << "\nHybridNonlinearFactor\n";
auto valueFormatter = [](const std::pair<sharedFactor, double>& v) {
auto [factor, val] = v;
if (factor) {
return "Nonlinear factor on " + std::to_string(factor->size()) + " keys";
} else {
return std::string("nullptr");
}
};
factors_.print("", keyFormatter, valueFormatter);
}
/* *******************************************************************************/
bool HybridNonlinearFactor::equals(const HybridFactor& other,
double tol) const {
// We attempt a dynamic cast from HybridFactor to HybridNonlinearFactor. If
// it fails, return false.
if (!dynamic_cast<const HybridNonlinearFactor*>(&other)) return false;
// If the cast is successful, we'll properly construct a
// HybridNonlinearFactor object from `other`
const HybridNonlinearFactor& f(
static_cast<const HybridNonlinearFactor&>(other));
// Ensure that this HybridNonlinearFactor and `f` have the same `factors_`.
auto compare = [tol](const std::pair<sharedFactor, double>& a,
const std::pair<sharedFactor, double>& b) {
return traits<NonlinearFactor>::Equals(*a.first, *b.first, tol) &&
(a.second == b.second);
};
if (!factors_.equals(f.factors_, compare)) return false;
// If everything above passes, and the keys_ and discreteKeys_
// member variables are identical, return true.
return (std::equal(keys_.begin(), keys_.end(), f.keys().begin()) &&
(discreteKeys_ == f.discreteKeys_));
}
/* *******************************************************************************/
GaussianFactor::shared_ptr HybridNonlinearFactor::linearize(
const Values& continuousValues,
const DiscreteValues& discreteValues) const {
auto factor = factors_(discreteValues).first;
return factor->linearize(continuousValues);
}
/* *******************************************************************************/
std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
const Values& continuousValues) const {
// functional to linearize each factor in the decision tree
auto linearizeDT =
[continuousValues](
const std::pair<sharedFactor, double>& f) -> GaussianFactorValuePair {
auto [factor, val] = f;
return {factor->linearize(continuousValues), val};
};
DecisionTree<Key, std::pair<GaussianFactor::shared_ptr, double>>
linearized_factors(factors_, linearizeDT);
return std::make_shared<HybridGaussianFactor>(continuousKeys_, discreteKeys_,
linearized_factors);
}
} // namespace gtsam

View File

@ -11,7 +11,7 @@
/** /**
* @file HybridNonlinearFactor.h * @file HybridNonlinearFactor.h
* @brief Nonlinear Mixture factor of continuous and discrete. * @brief A set of nonlinear factors indexed by a set of discrete keys.
* @author Kevin Doherty, kdoherty@mit.edu * @author Kevin Doherty, kdoherty@mit.edu
* @author Varun Agrawal * @author Varun Agrawal
* @date December 2021 * @date December 2021
@ -85,8 +85,7 @@ class HybridNonlinearFactor : public HybridFactor {
* @param factors Decision tree with of shared factors. * @param factors Decision tree with of shared factors.
*/ */
HybridNonlinearFactor(const KeyVector& keys, const DiscreteKeys& discreteKeys, HybridNonlinearFactor(const KeyVector& keys, const DiscreteKeys& discreteKeys,
const Factors& factors) const Factors& factors);
: Base(keys, discreteKeys), factors_(factors) {}
/** /**
* @brief Convenience constructor that generates the underlying factor * @brief Convenience constructor that generates the underlying factor
@ -140,16 +139,7 @@ class HybridNonlinearFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factor, and leaf values as the error. * as the factor, and leaf values as the error.
*/ */
AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const { AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const;
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const std::pair<sharedFactor, double>& f) {
auto [factor, val] = f;
return factor->error(continuousValues) + val;
};
DecisionTree<Key, double> result(factors_, errorFunc);
return result;
}
/** /**
* @brief Compute error of factor given both continuous and discrete values. * @brief Compute error of factor given both continuous and discrete values.
@ -159,13 +149,7 @@ class HybridNonlinearFactor : public HybridFactor {
* @return double The error of this factor. * @return double The error of this factor.
*/ */
double error(const Values& continuousValues, double error(const Values& continuousValues,
const DiscreteValues& discreteValues) const { const DiscreteValues& discreteValues) const;
// Retrieve the factor corresponding to the assignment in discreteValues.
auto [factor, val] = factors_(discreteValues);
// Compute the error for the selected factor
const double factorError = factor->error(continuousValues);
return factorError + val;
}
/** /**
* @brief Compute error of factor given hybrid values. * @brief Compute error of factor given hybrid values.
@ -173,67 +157,24 @@ class HybridNonlinearFactor : public HybridFactor {
* @param values The continuous Values and the discrete assignment. * @param values The continuous Values and the discrete assignment.
* @return double The error of this factor. * @return double The error of this factor.
*/ */
double error(const HybridValues& values) const override { double error(const HybridValues& values) const override;
return error(values.nonlinear(), values.discrete());
}
/** /**
* @brief Get the dimension of the factor (number of rows on linearization). * @brief Get the dimension of the factor (number of rows on linearization).
* Returns the dimension of the first component factor. * Returns the dimension of the first component factor.
* @return size_t * @return size_t
*/ */
size_t dim() const { size_t dim() const;
const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_);
auto [factor, val] = factors_(assignments.at(0));
return factor->dim();
}
/// Testable /// Testable
/// @{ /// @{
/// print to stdout /// print to stdout
void print( void print(const std::string& s = "", const KeyFormatter& keyFormatter =
const std::string& s = "", DefaultKeyFormatter) const override;
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
std::cout << (s.empty() ? "" : s + " ");
Base::print("", keyFormatter);
std::cout << "\nHybridNonlinearFactor\n";
auto valueFormatter = [](const std::pair<sharedFactor, double>& v) {
auto [factor, val] = v;
if (factor) {
return "Nonlinear factor on " + std::to_string(factor->size()) +
" keys";
} else {
return std::string("nullptr");
}
};
factors_.print("", keyFormatter, valueFormatter);
}
/// Check equality /// Check equality
bool equals(const HybridFactor& other, double tol = 1e-9) const override { bool equals(const HybridFactor& other, double tol = 1e-9) const override;
// We attempt a dynamic cast from HybridFactor to HybridNonlinearFactor. If
// it fails, return false.
if (!dynamic_cast<const HybridNonlinearFactor*>(&other)) return false;
// If the cast is successful, we'll properly construct a
// HybridNonlinearFactor object from `other`
const HybridNonlinearFactor& f(
static_cast<const HybridNonlinearFactor&>(other));
// Ensure that this HybridNonlinearFactor and `f` have the same `factors_`.
auto compare = [tol](const std::pair<sharedFactor, double>& a,
const std::pair<sharedFactor, double>& b) {
return traits<NonlinearFactor>::Equals(*a.first, *b.first, tol) &&
(a.second == b.second);
};
if (!factors_.equals(f.factors_, compare)) return false;
// If everything above passes, and the keys_ and discreteKeys_
// member variables are identical, return true.
return (std::equal(keys_.begin(), keys_.end(), f.keys().begin()) &&
(discreteKeys_ == f.discreteKeys_));
}
/// @} /// @}
@ -241,28 +182,11 @@ class HybridNonlinearFactor : public HybridFactor {
/// discreteValues. /// discreteValues.
GaussianFactor::shared_ptr linearize( GaussianFactor::shared_ptr linearize(
const Values& continuousValues, const Values& continuousValues,
const DiscreteValues& discreteValues) const { const DiscreteValues& discreteValues) const;
auto factor = factors_(discreteValues).first;
return factor->linearize(continuousValues);
}
/// Linearize all the continuous factors to get a HybridGaussianFactor. /// Linearize all the continuous factors to get a HybridGaussianFactor.
std::shared_ptr<HybridGaussianFactor> linearize( std::shared_ptr<HybridGaussianFactor> linearize(
const Values& continuousValues) const { const Values& continuousValues) const;
// functional to linearize each factor in the decision tree
auto linearizeDT =
[continuousValues](const std::pair<sharedFactor, double>& f)
-> GaussianFactorValuePair {
auto [factor, val] = f;
return {factor->linearize(continuousValues), val};
};
DecisionTree<Key, std::pair<GaussianFactor::shared_ptr, double>>
linearized_factors(factors_, linearizeDT);
return std::make_shared<HybridGaussianFactor>(
continuousKeys_, discreteKeys_, linearized_factors);
}
}; };
} // namespace gtsam } // namespace gtsam