throw in optimize

release/4.3a0
Frank Dellaert 2025-01-30 10:57:41 -05:00
parent a1467c5e84
commit 3d4d750151
2 changed files with 42 additions and 33 deletions

View File

@ -26,48 +26,48 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors, Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors,
const KeySet &newFactorKeys) { const KeySet &continuousKeys) {
// Get all the discrete keys from the factors // Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeySet(); KeySet allDiscrete = factors.discreteKeySet();
// Create KeyVector with continuous keys followed by discrete keys. // Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast; KeyVector lastKeys;
// Insert continuous keys first. // Insert continuous keys first.
for (auto &k : newFactorKeys) { for (auto &k : continuousKeys) {
if (!allDiscrete.exists(k)) { if (!allDiscrete.exists(k)) {
newKeysDiscreteLast.push_back(k); lastKeys.push_back(k);
} }
} }
// Insert discrete keys at the end // Insert discrete keys at the end
std::copy(allDiscrete.begin(), allDiscrete.end(), std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast)); std::back_inserter(lastKeys));
const VariableIndex index(factors); const VariableIndex index(factors);
// Get an ordering where the new keys are eliminated last // Get an ordering where the new keys are eliminated last
Ordering ordering = Ordering::ColamdConstrainedLast( Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()), index, KeyVector(lastKeys.begin(), lastKeys.end()), true);
true);
return ordering; return ordering;
} }
/* ************************************************************************* */ /* ************************************************************************* */
void HybridSmoother::update(const HybridGaussianFactorGraph &graph, void HybridSmoother::update(const HybridGaussianFactorGraph &newFactors,
std::optional<size_t> maxNrLeaves, std::optional<size_t> maxNrLeaves,
const std::optional<Ordering> given_ordering) { const std::optional<Ordering> given_ordering) {
const KeySet originalNewFactorKeys = newFactors.keys();
#ifdef DEBUG_SMOOTHER #ifdef DEBUG_SMOOTHER
std::cout << "hybridBayesNet_ size before: " << hybridBayesNet_.size() std::cout << "hybridBayesNet_ size before: " << hybridBayesNet_.size()
<< std::endl; << std::endl;
std::cout << "newFactors size: " << graph.size() << std::endl; std::cout << "newFactors size: " << newFactors.size() << std::endl;
#endif #endif
HybridGaussianFactorGraph updatedGraph; HybridGaussianFactorGraph updatedGraph;
// Add the necessary conditionals from the previous timestep(s). // Add the necessary conditionals from the previous timestep(s).
std::tie(updatedGraph, hybridBayesNet_) = std::tie(updatedGraph, hybridBayesNet_) =
addConditionals(graph, hybridBayesNet_); addConditionals(newFactors, hybridBayesNet_);
#ifdef DEBUG_SMOOTHER #ifdef DEBUG_SMOOTHER
// print size of graph, updatedGraph, hybridBayesNet_ // print size of newFactors, updatedGraph, hybridBayesNet_
std::cout << "updatedGraph size: " << updatedGraph.size() << std::endl; std::cout << "updatedGraph size: " << updatedGraph.size() << std::endl;
std::cout << "hybridBayesNet_ size after: " << hybridBayesNet_.size() std::cout << "hybridBayesNet_ size after: " << hybridBayesNet_.size()
<< std::endl; << std::endl;
@ -79,11 +79,11 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
// If no ordering provided, then we compute one // If no ordering provided, then we compute one
if (!given_ordering.has_value()) { if (!given_ordering.has_value()) {
// Get the keys from the new factors // Get the keys from the new factors
const KeySet newFactorKeys = graph.keys(); const KeySet continuousKeysToInclude;// = newFactors.keys();
// Since updatedGraph now has all the connected conditionals, // Since updatedGraph now has all the connected conditionals,
// we can get the correct ordering. // we can get the correct ordering.
ordering = this->getOrdering(updatedGraph, newFactorKeys); ordering = this->getOrdering(updatedGraph, continuousKeysToInclude);
} else { } else {
ordering = *given_ordering; ordering = *given_ordering;
} }
@ -140,12 +140,15 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
/* ************************************************************************* */ /* ************************************************************************* */
std::pair<HybridGaussianFactorGraph, HybridBayesNet> std::pair<HybridGaussianFactorGraph, HybridBayesNet>
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, HybridSmoother::addConditionals(const HybridGaussianFactorGraph &newFactors,
const HybridBayesNet &hybridBayesNet) const { const HybridBayesNet &hybridBayesNet) const {
HybridGaussianFactorGraph graph(originalGraph); HybridGaussianFactorGraph graph(newFactors);
HybridBayesNet updatedHybridBayesNet(hybridBayesNet); HybridBayesNet updatedHybridBayesNet(hybridBayesNet);
KeySet factorKeys = graph.keys(); KeySet involvedKeys = newFactors.keys();
auto involved = [&involvedKeys](const Key &key) {
return involvedKeys.find(key) != involvedKeys.end();
};
// If hybridBayesNet is not empty, // If hybridBayesNet is not empty,
// it means we have conditionals to add to the factor graph. // it means we have conditionals to add to the factor graph.
@ -167,12 +170,11 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
auto conditional = hybridBayesNet.at(i); auto conditional = hybridBayesNet.at(i);
for (auto &key : conditional->frontals()) { for (auto &key : conditional->frontals()) {
if (std::find(factorKeys.begin(), factorKeys.end(), key) != if (involved(key)) {
factorKeys.end()) { // Add the conditional parents to involvedKeys
// Add the conditional parents to factorKeys
// so we add those conditionals too. // so we add those conditionals too.
for (auto &&parentKey : conditional->parents()) { for (auto &&parentKey : conditional->parents()) {
factorKeys.insert(parentKey); involvedKeys.insert(parentKey);
} }
// Break so we don't add parents twice. // Break so we don't add parents twice.
break; break;
@ -180,15 +182,14 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
} }
} }
#ifdef DEBUG_SMOOTHER #ifdef DEBUG_SMOOTHER
PrintKeySet(factorKeys); PrintKeySet(involvedKeys);
#endif #endif
for (size_t i = 0; i < hybridBayesNet.size(); i++) { for (size_t i = 0; i < hybridBayesNet.size(); i++) {
auto conditional = hybridBayesNet.at(i); auto conditional = hybridBayesNet.at(i);
for (auto &key : conditional->frontals()) { for (auto &key : conditional->frontals()) {
if (std::find(factorKeys.begin(), factorKeys.end(), key) != if (involved(key)) {
factorKeys.end()) {
newConditionals.push_back(conditional); newConditionals.push_back(conditional);
// Remove the conditional from the updated Bayes net // Remove the conditional from the updated Bayes net
@ -218,4 +219,21 @@ const HybridBayesNet &HybridSmoother::hybridBayesNet() const {
return hybridBayesNet_; return hybridBayesNet_;
} }
/* ************************************************************************* */
HybridValues HybridSmoother::optimize() const {
  // Solve for the most probable explanation (MPE) over the discrete variables.
  DiscreteValues mpe = hybridBayesNet_.mpe();

  // Add fixed values to the MPE.
  mpe.insert(fixedValues_);

  // Given the MPE assignment, select the corresponding Gaussian Bayes net
  // over the continuous variables.
  GaussianBayesNet gbn = hybridBayesNet_.choose(mpe);

  // Guard BEFORE optimizing: if any conditional in the chosen Bayes net is
  // nullptr (e.g. pruned away under this assignment — TODO confirm source of
  // nulls), calling gbn.optimize() would dereference it. The original code
  // performed this check only after optimize(), too late to prevent a crash.
  if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
    throw std::runtime_error("At least one nullptr factor in hybridBayesNet_");
  }

  // Given the MPE, compute the optimal continuous values.
  const VectorValues continuous = gbn.optimize();

  return HybridValues(continuous, mpe);
}
} // namespace gtsam } // namespace gtsam

View File

@ -108,16 +108,7 @@ class GTSAM_EXPORT HybridSmoother {
const HybridBayesNet& hybridBayesNet() const; const HybridBayesNet& hybridBayesNet() const;
/// Optimize the hybrid Bayes Net, taking into account fixed values. /// Optimize the hybrid Bayes Net, taking into account fixed values.
HybridValues optimize() const { HybridValues optimize() const;
// Solve for the MPE
DiscreteValues mpe = hybridBayesNet_.mpe();
// Add fixed values to the MPE.
mpe.insert(fixedValues_);
// Given the MPE, compute the optimal continuous values.
return HybridValues(hybridBayesNet_.optimize(mpe), mpe);
}
}; };
} // namespace gtsam } // namespace gtsam