throw in optimize
parent
a1467c5e84
commit
3d4d750151
|
@ -26,48 +26,48 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors,
|
Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors,
|
||||||
const KeySet &newFactorKeys) {
|
const KeySet &continuousKeys) {
|
||||||
// Get all the discrete keys from the factors
|
// Get all the discrete keys from the factors
|
||||||
KeySet allDiscrete = factors.discreteKeySet();
|
KeySet allDiscrete = factors.discreteKeySet();
|
||||||
|
|
||||||
// Create KeyVector with continuous keys followed by discrete keys.
|
// Create KeyVector with continuous keys followed by discrete keys.
|
||||||
KeyVector newKeysDiscreteLast;
|
KeyVector lastKeys;
|
||||||
|
|
||||||
// Insert continuous keys first.
|
// Insert continuous keys first.
|
||||||
for (auto &k : newFactorKeys) {
|
for (auto &k : continuousKeys) {
|
||||||
if (!allDiscrete.exists(k)) {
|
if (!allDiscrete.exists(k)) {
|
||||||
newKeysDiscreteLast.push_back(k);
|
lastKeys.push_back(k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert discrete keys at the end
|
// Insert discrete keys at the end
|
||||||
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
||||||
std::back_inserter(newKeysDiscreteLast));
|
std::back_inserter(lastKeys));
|
||||||
|
|
||||||
const VariableIndex index(factors);
|
const VariableIndex index(factors);
|
||||||
|
|
||||||
// Get an ordering where the new keys are eliminated last
|
// Get an ordering where the new keys are eliminated last
|
||||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||||
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
|
index, KeyVector(lastKeys.begin(), lastKeys.end()), true);
|
||||||
true);
|
|
||||||
return ordering;
|
return ordering;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
void HybridSmoother::update(const HybridGaussianFactorGraph &newFactors,
|
||||||
std::optional<size_t> maxNrLeaves,
|
std::optional<size_t> maxNrLeaves,
|
||||||
const std::optional<Ordering> given_ordering) {
|
const std::optional<Ordering> given_ordering) {
|
||||||
|
const KeySet originalNewFactorKeys = newFactors.keys();
|
||||||
#ifdef DEBUG_SMOOTHER
|
#ifdef DEBUG_SMOOTHER
|
||||||
std::cout << "hybridBayesNet_ size before: " << hybridBayesNet_.size()
|
std::cout << "hybridBayesNet_ size before: " << hybridBayesNet_.size()
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
std::cout << "newFactors size: " << graph.size() << std::endl;
|
std::cout << "newFactors size: " << newFactors.size() << std::endl;
|
||||||
#endif
|
#endif
|
||||||
HybridGaussianFactorGraph updatedGraph;
|
HybridGaussianFactorGraph updatedGraph;
|
||||||
// Add the necessary conditionals from the previous timestep(s).
|
// Add the necessary conditionals from the previous timestep(s).
|
||||||
std::tie(updatedGraph, hybridBayesNet_) =
|
std::tie(updatedGraph, hybridBayesNet_) =
|
||||||
addConditionals(graph, hybridBayesNet_);
|
addConditionals(newFactors, hybridBayesNet_);
|
||||||
#ifdef DEBUG_SMOOTHER
|
#ifdef DEBUG_SMOOTHER
|
||||||
// print size of graph, updatedGraph, hybridBayesNet_
|
// print size of newFactors, updatedGraph, hybridBayesNet_
|
||||||
std::cout << "updatedGraph size: " << updatedGraph.size() << std::endl;
|
std::cout << "updatedGraph size: " << updatedGraph.size() << std::endl;
|
||||||
std::cout << "hybridBayesNet_ size after: " << hybridBayesNet_.size()
|
std::cout << "hybridBayesNet_ size after: " << hybridBayesNet_.size()
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
@ -79,11 +79,11 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
||||||
// If no ordering provided, then we compute one
|
// If no ordering provided, then we compute one
|
||||||
if (!given_ordering.has_value()) {
|
if (!given_ordering.has_value()) {
|
||||||
// Get the keys from the new factors
|
// Get the keys from the new factors
|
||||||
const KeySet newFactorKeys = graph.keys();
|
const KeySet continuousKeysToInclude;// = newFactors.keys();
|
||||||
|
|
||||||
// Since updatedGraph now has all the connected conditionals,
|
// Since updatedGraph now has all the connected conditionals,
|
||||||
// we can get the correct ordering.
|
// we can get the correct ordering.
|
||||||
ordering = this->getOrdering(updatedGraph, newFactorKeys);
|
ordering = this->getOrdering(updatedGraph, continuousKeysToInclude);
|
||||||
} else {
|
} else {
|
||||||
ordering = *given_ordering;
|
ordering = *given_ordering;
|
||||||
}
|
}
|
||||||
|
@ -140,12 +140,15 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
|
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
|
||||||
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &newFactors,
|
||||||
const HybridBayesNet &hybridBayesNet) const {
|
const HybridBayesNet &hybridBayesNet) const {
|
||||||
HybridGaussianFactorGraph graph(originalGraph);
|
HybridGaussianFactorGraph graph(newFactors);
|
||||||
HybridBayesNet updatedHybridBayesNet(hybridBayesNet);
|
HybridBayesNet updatedHybridBayesNet(hybridBayesNet);
|
||||||
|
|
||||||
KeySet factorKeys = graph.keys();
|
KeySet involvedKeys = newFactors.keys();
|
||||||
|
auto involved = [&involvedKeys](const Key &key) {
|
||||||
|
return involvedKeys.find(key) != involvedKeys.end();
|
||||||
|
};
|
||||||
|
|
||||||
// If hybridBayesNet is not empty,
|
// If hybridBayesNet is not empty,
|
||||||
// it means we have conditionals to add to the factor graph.
|
// it means we have conditionals to add to the factor graph.
|
||||||
|
@ -167,12 +170,11 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
auto conditional = hybridBayesNet.at(i);
|
auto conditional = hybridBayesNet.at(i);
|
||||||
|
|
||||||
for (auto &key : conditional->frontals()) {
|
for (auto &key : conditional->frontals()) {
|
||||||
if (std::find(factorKeys.begin(), factorKeys.end(), key) !=
|
if (involved(key)) {
|
||||||
factorKeys.end()) {
|
// Add the conditional parents to involvedKeys
|
||||||
// Add the conditional parents to factorKeys
|
|
||||||
// so we add those conditionals too.
|
// so we add those conditionals too.
|
||||||
for (auto &&parentKey : conditional->parents()) {
|
for (auto &&parentKey : conditional->parents()) {
|
||||||
factorKeys.insert(parentKey);
|
involvedKeys.insert(parentKey);
|
||||||
}
|
}
|
||||||
// Break so we don't add parents twice.
|
// Break so we don't add parents twice.
|
||||||
break;
|
break;
|
||||||
|
@ -180,15 +182,14 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#ifdef DEBUG_SMOOTHER
|
#ifdef DEBUG_SMOOTHER
|
||||||
PrintKeySet(factorKeys);
|
PrintKeySet(involvedKeys);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
for (size_t i = 0; i < hybridBayesNet.size(); i++) {
|
for (size_t i = 0; i < hybridBayesNet.size(); i++) {
|
||||||
auto conditional = hybridBayesNet.at(i);
|
auto conditional = hybridBayesNet.at(i);
|
||||||
|
|
||||||
for (auto &key : conditional->frontals()) {
|
for (auto &key : conditional->frontals()) {
|
||||||
if (std::find(factorKeys.begin(), factorKeys.end(), key) !=
|
if (involved(key)) {
|
||||||
factorKeys.end()) {
|
|
||||||
newConditionals.push_back(conditional);
|
newConditionals.push_back(conditional);
|
||||||
|
|
||||||
// Remove the conditional from the updated Bayes net
|
// Remove the conditional from the updated Bayes net
|
||||||
|
@ -218,4 +219,21 @@ const HybridBayesNet &HybridSmoother::hybridBayesNet() const {
|
||||||
return hybridBayesNet_;
|
return hybridBayesNet_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridValues HybridSmoother::optimize() const {
|
||||||
|
// Solve for the MPE
|
||||||
|
DiscreteValues mpe = hybridBayesNet_.mpe();
|
||||||
|
|
||||||
|
// Add fixed values to the MPE.
|
||||||
|
mpe.insert(fixedValues_);
|
||||||
|
|
||||||
|
// Given the MPE, compute the optimal continuous values.
|
||||||
|
GaussianBayesNet gbn = hybridBayesNet_.choose(mpe);
|
||||||
|
const VectorValues continuous = gbn.optimize();
|
||||||
|
if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
|
||||||
|
throw std::runtime_error("At least one nullptr factor in hybridBayesNet_");
|
||||||
|
}
|
||||||
|
return HybridValues(continuous, mpe);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -108,16 +108,7 @@ class GTSAM_EXPORT HybridSmoother {
|
||||||
const HybridBayesNet& hybridBayesNet() const;
|
const HybridBayesNet& hybridBayesNet() const;
|
||||||
|
|
||||||
/// Optimize the hybrid Bayes Net, taking into accound fixed values.
|
/// Optimize the hybrid Bayes Net, taking into accound fixed values.
|
||||||
HybridValues optimize() const {
|
HybridValues optimize() const;
|
||||||
// Solve for the MPE
|
|
||||||
DiscreteValues mpe = hybridBayesNet_.mpe();
|
|
||||||
|
|
||||||
// Add fixed values to the MPE.
|
|
||||||
mpe.insert(fixedValues_);
|
|
||||||
|
|
||||||
// Given the MPE, compute the optimal continuous values.
|
|
||||||
return HybridValues(hybridBayesNet_.optimize(mpe), mpe);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
Loading…
Reference in New Issue