Merge pull request #2062 from borglab/nonlinear-hybrid-smoother

Nonlinear Hybrid Smoother
release/4.3a0
Varun Agrawal 2025-03-18 22:39:49 -04:00 committed by GitHub
commit d01aaf0c84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 63 additions and 28 deletions

View File

@ -96,9 +96,17 @@ void HybridSmoother::removeFixedValues(
} }
/* ************************************************************************* */ /* ************************************************************************* */
void HybridSmoother::update(const HybridGaussianFactorGraph &newFactors, void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors,
const Values &initial,
std::optional<size_t> maxNrLeaves, std::optional<size_t> maxNrLeaves,
const std::optional<Ordering> given_ordering) { const std::optional<Ordering> given_ordering) {
HybridGaussianFactorGraph linearizedFactors = *newFactors.linearize(initial);
// Record the new nonlinear factors and
// linearization point for relinearization
allFactors_.push_back(newFactors);
linearizationPoint_.insert_or_assign(initial);
const KeySet originalNewFactorKeys = newFactors.keys(); const KeySet originalNewFactorKeys = newFactors.keys();
#ifdef DEBUG_SMOOTHER #ifdef DEBUG_SMOOTHER
std::cout << "hybridBayesNet_ size before: " << hybridBayesNet_.size() std::cout << "hybridBayesNet_ size before: " << hybridBayesNet_.size()
@ -108,7 +116,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &newFactors,
HybridGaussianFactorGraph updatedGraph; HybridGaussianFactorGraph updatedGraph;
// Add the necessary conditionals from the previous timestep(s). // Add the necessary conditionals from the previous timestep(s).
std::tie(updatedGraph, hybridBayesNet_) = std::tie(updatedGraph, hybridBayesNet_) =
addConditionals(newFactors, hybridBayesNet_); addConditionals(linearizedFactors, hybridBayesNet_);
#ifdef DEBUG_SMOOTHER #ifdef DEBUG_SMOOTHER
// print size of newFactors, updatedGraph, hybridBayesNet_ // print size of newFactors, updatedGraph, hybridBayesNet_
std::cout << "updatedGraph size: " << updatedGraph.size() << std::endl; std::cout << "updatedGraph size: " << updatedGraph.size() << std::endl;
@ -283,4 +291,25 @@ HybridValues HybridSmoother::optimize() const {
return HybridValues(continuous, mpe); return HybridValues(continuous, mpe);
} }
/* ************************************************************************* */
void HybridSmoother::relinearize() {
allFactors_ = allFactors_.restrict(fixedValues_);
HybridGaussianFactorGraph::shared_ptr linearized =
allFactors_.linearize(linearizationPoint_);
HybridBayesNet::shared_ptr bayesNet = linearized->eliminateSequential();
HybridValues delta = bayesNet->optimize();
linearizationPoint_ = linearizationPoint_.retract(delta.continuous());
reInitialize(*bayesNet);
}
/* ************************************************************************* */
Values HybridSmoother::linearizationPoint() const {
  // Return (a copy of) the values all recorded factors are linearized about.
  return this->linearizationPoint_;
}
/* ************************************************************************* */
HybridNonlinearFactorGraph HybridSmoother::allFactors() const {
  // Return (a copy of) every nonlinear factor recorded via update().
  return this->allFactors_;
}
} // namespace gtsam } // namespace gtsam

View File

@ -19,6 +19,7 @@
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h> #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
#include <optional> #include <optional>
@ -26,8 +27,10 @@ namespace gtsam {
class GTSAM_EXPORT HybridSmoother { class GTSAM_EXPORT HybridSmoother {
private: private:
HybridBayesNet hybridBayesNet_; HybridNonlinearFactorGraph allFactors_;
Values linearizationPoint_;
HybridBayesNet hybridBayesNet_;
/// The threshold above which we make a decision about a mode. /// The threshold above which we make a decision about a mode.
std::optional<double> marginalThreshold_; std::optional<double> marginalThreshold_;
DiscreteValues fixedValues_; DiscreteValues fixedValues_;
@ -73,12 +76,12 @@ class GTSAM_EXPORT HybridSmoother {
* @param graph The new factors, should be linear only * @param graph The new nonlinear factors to record and linearize
* @param maxNrLeaves The maximum number of leaves in the new discrete factor, * @param maxNrLeaves The maximum number of leaves in the new discrete factor,
* if applicable * if applicable
* @param given_ordering The (optional) ordering for elimination, only * @param givenOrdering The (optional) ordering for elimination, only
* continuous variables are allowed * continuous variables are allowed
*/ */
void update(const HybridGaussianFactorGraph& graph, void update(const HybridNonlinearFactorGraph& graph, const Values& initial,
std::optional<size_t> maxNrLeaves = {}, std::optional<size_t> maxNrLeaves = {},
const std::optional<Ordering> given_ordering = {}); const std::optional<Ordering> givenOrdering = {});
/** /**
* @brief Get an elimination ordering which eliminates continuous * @brief Get an elimination ordering which eliminates continuous
@ -123,6 +126,16 @@ class GTSAM_EXPORT HybridSmoother {
/// Optimize the hybrid Bayes Net, taking into account fixed values. /// Optimize the hybrid Bayes Net, taking into account fixed values.
HybridValues optimize() const; HybridValues optimize() const;
/// Relinearize the nonlinear factor graph
/// with the latest linearization point.
void relinearize();
/// Return the current linearization point.
Values linearizationPoint() const;
/// Return all the recorded nonlinear factors
HybridNonlinearFactorGraph allFactors() const;
private: private:
/// Helper to compute the ordering if ordering is not given. /// Helper to compute the ordering if ordering is not given.
Ordering maybeComputeOrdering(const HybridGaussianFactorGraph& updatedGraph, Ordering maybeComputeOrdering(const HybridGaussianFactorGraph& updatedGraph,

View File

@ -283,10 +283,16 @@ class HybridSmoother {
void reInitialize(gtsam::HybridBayesNet& hybridBayesNet); void reInitialize(gtsam::HybridBayesNet& hybridBayesNet);
void update( void update(
const gtsam::HybridGaussianFactorGraph& graph, const gtsam::HybridNonlinearFactorGraph& graph,
const gtsam::Values& initial,
std::optional<size_t> maxNrLeaves = std::nullopt, std::optional<size_t> maxNrLeaves = std::nullopt,
const std::optional<gtsam::Ordering> given_ordering = std::nullopt); const std::optional<gtsam::Ordering> given_ordering = std::nullopt);
void relinearize();
gtsam::Values linearizationPoint() const;
gtsam::HybridNonlinearFactorGraph allFactors() const;
gtsam::Ordering getOrdering(const gtsam::HybridGaussianFactorGraph& factors, gtsam::Ordering getOrdering(const gtsam::HybridGaussianFactorGraph& factors,
const gtsam::KeySet& newFactorKeys); const gtsam::KeySet& newFactorKeys);

View File

@ -94,9 +94,7 @@ TEST(HybridSmoother, IncrementalSmoother) {
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k))); initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
HybridGaussianFactorGraph linearized = *graph.linearize(initial); smoother.update(graph, initial, maxNrLeaves);
smoother.update(linearized, maxNrLeaves);
// Clear all the factors from the graph // Clear all the factors from the graph
graph.resize(0); graph.resize(0);
@ -152,9 +150,7 @@ TEST(HybridSmoother, ValidPruningError) {
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k))); initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
HybridGaussianFactorGraph linearized = *graph.linearize(initial); smoother.update(graph, initial, maxNrLeaves);
smoother.update(linearized, maxNrLeaves);
// Clear all the factors from the graph // Clear all the factors from the graph
graph.resize(0); graph.resize(0);
@ -200,9 +196,7 @@ TEST(HybridSmoother, DeadModeRemoval) {
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k))); initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
HybridGaussianFactorGraph linearized = *graph.linearize(initial); smoother.update(graph, initial, maxNrLeaves);
smoother.update(linearized, maxNrLeaves);
// Clear all the factors from the graph // Clear all the factors from the graph
graph.resize(0); graph.resize(0);

View File

@ -194,7 +194,6 @@ class Experiment:
self.smoother_ = HybridSmoother(marginal_threshold) self.smoother_ = HybridSmoother(marginal_threshold)
self.new_factors_ = HybridNonlinearFactorGraph() self.new_factors_ = HybridNonlinearFactorGraph()
self.all_factors_ = HybridNonlinearFactorGraph()
self.initial_ = Values() self.initial_ = Values()
self.plot_hypotheses = plot_hypotheses self.plot_hypotheses = plot_hypotheses
@ -231,24 +230,18 @@ class Experiment:
"""Perform smoother update and optimize the graph.""" """Perform smoother update and optimize the graph."""
print(f"Smoother update: {self.new_factors_.size()}") print(f"Smoother update: {self.new_factors_.size()}")
before_update = time.time() before_update = time.time()
linearized = self.new_factors_.linearize(self.initial_) self.smoother_.update(self.new_factors_, self.initial_,
self.smoother_.update(linearized, max_num_hypotheses) max_num_hypotheses)
self.all_factors_.push_back(self.new_factors_)
self.new_factors_.resize(0) self.new_factors_.resize(0)
after_update = time.time() after_update = time.time()
return after_update - before_update return after_update - before_update
def reinitialize(self) -> float: def reinitialize(self) -> float:
"""Re-linearize, solve ALL, and re-initialize smoother.""" """Re-linearize, solve ALL, and re-initialize smoother."""
print(f"================= Re-Initialize: {self.all_factors_.size()}") print(f"================= Re-Initialize: {self.smoother_.allFactors().size()}")
before_update = time.time() before_update = time.time()
self.all_factors_ = self.all_factors_.restrict( self.smoother_.relinearize()
self.smoother_.fixedValues()) self.initial_ = self.smoother_.linearizationPoint()
linearized = self.all_factors_.linearize(self.initial_)
bayesNet = linearized.eliminateSequential()
delta: HybridValues = bayesNet.optimize()
self.initial_ = self.initial_.retract(delta.continuous())
self.smoother_.reInitialize(bayesNet)
after_update = time.time() after_update = time.time()
print(f"Took {after_update - before_update} seconds.") print(f"Took {after_update - before_update} seconds.")
return after_update - before_update return after_update - before_update