fix hybrid elimination
parent
a97d27e981
commit
a6d1a57478
|
@ -51,6 +51,8 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// #define HYBRID_TIMING
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||
|
@ -256,13 +258,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
// DiscreteFactor, with the error for each discrete choice.
|
||||
if (keysOfSeparator.empty()) {
|
||||
// TODO(Varun) Use the math from the iMHS_Math-1-indexed document
|
||||
// TODO(Varun) The prob of a leaf should be computed from the full Bayes Net
|
||||
VectorValues empty_values;
|
||||
auto factorError = [&](const GaussianFactor::shared_ptr &factor) {
|
||||
if (!factor) return 0.0; // TODO(fan): does this make sense?
|
||||
return exp(-factor->error(empty_values));
|
||||
auto factorProb = [&](const GaussianFactor::shared_ptr &factor) {
|
||||
if (!factor) {
|
||||
return 0.0; // If nullptr, return 0.0 probability
|
||||
} else {
|
||||
return 1.0;
|
||||
}
|
||||
};
|
||||
DecisionTree<Key, double> fdt(separatorFactors, factorError);
|
||||
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
|
||||
|
||||
auto discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||
|
@ -488,4 +492,65 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
|||
return prob_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
||||
HybridGaussianFactorGraph::eliminateHybridSequential() const {
|
||||
Ordering continuous_ordering(this->continuousKeys()),
|
||||
discrete_ordering(this->discreteKeys());
|
||||
|
||||
// Eliminate continuous
|
||||
HybridBayesNet::shared_ptr bayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||
std::tie(bayesNet, discreteGraph) =
|
||||
BaseEliminateable::eliminatePartialSequential(
|
||||
continuous_ordering, EliminationTraitsType::DefaultEliminate);
|
||||
|
||||
// Get the last continuous conditional which will have all the discrete keys
|
||||
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
|
||||
// Get all the discrete assignments
|
||||
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(discrete_keys);
|
||||
|
||||
// Reverse discrete keys order for correct tree construction
|
||||
std::reverse(discrete_keys.begin(), discrete_keys.end());
|
||||
|
||||
// Create a decision tree of all the different VectorValues
|
||||
std::vector<VectorValues::shared_ptr> vector_values;
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
VectorValues values = bayesNet->optimize(assignment);
|
||||
vector_values.push_back(boost::make_shared<VectorValues>(values));
|
||||
}
|
||||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
|
||||
vector_values);
|
||||
|
||||
// Get the probPrime tree with the correct leaf probabilities
|
||||
std::vector<double> probPrimes;
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
double error = 0.0;
|
||||
VectorValues delta = *delta_tree(assignment);
|
||||
for (auto factor : *this) {
|
||||
if (factor->isHybrid()) {
|
||||
auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
|
||||
error += f->error(delta, assignment);
|
||||
|
||||
} else if (factor->isContinuous()) {
|
||||
auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor);
|
||||
error += f->inner()->error(delta);
|
||||
}
|
||||
}
|
||||
probPrimes.push_back(exp(-error));
|
||||
}
|
||||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
|
||||
// Perform discrete elimination
|
||||
HybridBayesNet::shared_ptr discreteBayesNet =
|
||||
discreteGraph->eliminateSequential(
|
||||
discrete_ordering, EliminationTraitsType::DefaultEliminate);
|
||||
bayesNet->add(*discreteBayesNet);
|
||||
|
||||
return bayesNet;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -190,6 +190,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
AlgebraicDecisionTree<Key> probPrime(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
boost::shared_ptr<BayesNetType> eliminateHybridSequential() const;
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
* eliminated after the continuous keys.
|
||||
|
|
Loading…
Reference in New Issue