throw in optimize

release/4.3a0
Frank Dellaert 2025-01-30 10:57:41 -05:00
parent a1467c5e84
commit 3d4d750151
2 changed files with 42 additions and 33 deletions

View File

@ -26,48 +26,48 @@ namespace gtsam {
/* ************************************************************************* */
Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors,
const KeySet &newFactorKeys) {
const KeySet &continuousKeys) {
// Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeySet();
// Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast;
KeyVector lastKeys;
// Insert continuous keys first.
for (auto &k : newFactorKeys) {
for (auto &k : continuousKeys) {
if (!allDiscrete.exists(k)) {
newKeysDiscreteLast.push_back(k);
lastKeys.push_back(k);
}
}
// Insert discrete keys at the end
std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast));
std::back_inserter(lastKeys));
const VariableIndex index(factors);
// Get an ordering where the new keys are eliminated last
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
true);
index, KeyVector(lastKeys.begin(), lastKeys.end()), true);
return ordering;
}
/* ************************************************************************* */
void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
void HybridSmoother::update(const HybridGaussianFactorGraph &newFactors,
std::optional<size_t> maxNrLeaves,
const std::optional<Ordering> given_ordering) {
const KeySet originalNewFactorKeys = newFactors.keys();
#ifdef DEBUG_SMOOTHER
std::cout << "hybridBayesNet_ size before: " << hybridBayesNet_.size()
<< std::endl;
std::cout << "newFactors size: " << graph.size() << std::endl;
std::cout << "newFactors size: " << newFactors.size() << std::endl;
#endif
HybridGaussianFactorGraph updatedGraph;
// Add the necessary conditionals from the previous timestep(s).
std::tie(updatedGraph, hybridBayesNet_) =
addConditionals(graph, hybridBayesNet_);
addConditionals(newFactors, hybridBayesNet_);
#ifdef DEBUG_SMOOTHER
// print size of graph, updatedGraph, hybridBayesNet_
// print size of newFactors, updatedGraph, hybridBayesNet_
std::cout << "updatedGraph size: " << updatedGraph.size() << std::endl;
std::cout << "hybridBayesNet_ size after: " << hybridBayesNet_.size()
<< std::endl;
@ -79,11 +79,11 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
// If no ordering provided, then we compute one
if (!given_ordering.has_value()) {
// Get the keys from the new factors
const KeySet newFactorKeys = graph.keys();
const KeySet continuousKeysToInclude;// = newFactors.keys();
// Since updatedGraph now has all the connected conditionals,
// we can get the correct ordering.
ordering = this->getOrdering(updatedGraph, newFactorKeys);
ordering = this->getOrdering(updatedGraph, continuousKeysToInclude);
} else {
ordering = *given_ordering;
}
@ -140,12 +140,15 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
/* ************************************************************************* */
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &newFactors,
const HybridBayesNet &hybridBayesNet) const {
HybridGaussianFactorGraph graph(originalGraph);
HybridGaussianFactorGraph graph(newFactors);
HybridBayesNet updatedHybridBayesNet(hybridBayesNet);
KeySet factorKeys = graph.keys();
KeySet involvedKeys = newFactors.keys();
auto involved = [&involvedKeys](const Key &key) {
return involvedKeys.find(key) != involvedKeys.end();
};
// If hybridBayesNet is not empty,
// it means we have conditionals to add to the factor graph.
@ -167,12 +170,11 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
auto conditional = hybridBayesNet.at(i);
for (auto &key : conditional->frontals()) {
if (std::find(factorKeys.begin(), factorKeys.end(), key) !=
factorKeys.end()) {
// Add the conditional parents to factorKeys
if (involved(key)) {
// Add the conditional parents to involvedKeys
// so we add those conditionals too.
for (auto &&parentKey : conditional->parents()) {
factorKeys.insert(parentKey);
involvedKeys.insert(parentKey);
}
// Break so we don't add parents twice.
break;
@ -180,15 +182,14 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
}
}
#ifdef DEBUG_SMOOTHER
PrintKeySet(factorKeys);
PrintKeySet(involvedKeys);
#endif
for (size_t i = 0; i < hybridBayesNet.size(); i++) {
auto conditional = hybridBayesNet.at(i);
for (auto &key : conditional->frontals()) {
if (std::find(factorKeys.begin(), factorKeys.end(), key) !=
factorKeys.end()) {
if (involved(key)) {
newConditionals.push_back(conditional);
// Remove the conditional from the updated Bayes net
@ -218,4 +219,21 @@ const HybridBayesNet &HybridSmoother::hybridBayesNet() const {
return hybridBayesNet_;
}
/* ************************************************************************* */
HybridValues HybridSmoother::optimize() const {
// Solve for the MPE
DiscreteValues mpe = hybridBayesNet_.mpe();
// Add fixed values to the MPE.
mpe.insert(fixedValues_);
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = hybridBayesNet_.choose(mpe);
const VectorValues continuous = gbn.optimize();
if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
throw std::runtime_error("At least one nullptr factor in hybridBayesNet_");
}
return HybridValues(continuous, mpe);
}
} // namespace gtsam

View File

@ -108,16 +108,7 @@ class GTSAM_EXPORT HybridSmoother {
const HybridBayesNet& hybridBayesNet() const;
/// Optimize the hybrid Bayes Net, taking into accound fixed values.
HybridValues optimize() const {
// Solve for the MPE
DiscreteValues mpe = hybridBayesNet_.mpe();
// Add fixed values to the MPE.
mpe.insert(fixedValues_);
// Given the MPE, compute the optimal continuous values.
return HybridValues(hybridBayesNet_.optimize(mpe), mpe);
}
HybridValues optimize() const;
};
} // namespace gtsam