fix hybrid elimination

release/4.3a0
Varun Agrawal 2022-11-07 16:10:48 -05:00
parent a97d27e981
commit a6d1a57478
2 changed files with 72 additions and 5 deletions

View File

@ -51,6 +51,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
// #define HYBRID_TIMING
namespace gtsam { namespace gtsam {
template class EliminateableFactorGraph<HybridGaussianFactorGraph>; template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
@ -256,13 +258,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// DiscreteFactor, with the error for each discrete choice. // DiscreteFactor, with the error for each discrete choice.
if (keysOfSeparator.empty()) { if (keysOfSeparator.empty()) {
// TODO(Varun) Use the math from the iMHS_Math-1-indexed document // TODO(Varun) Use the math from the iMHS_Math-1-indexed document
// TODO(Varun) The prob of a leaf should be computed from the full Bayes Net
VectorValues empty_values; VectorValues empty_values;
auto factorError = [&](const GaussianFactor::shared_ptr &factor) { auto factorProb = [&](const GaussianFactor::shared_ptr &factor) {
if (!factor) return 0.0; // TODO(fan): does this make sense? if (!factor) {
return exp(-factor->error(empty_values)); return 0.0; // If nullptr, return 0.0 probability
} else {
return 1.0;
}
}; };
DecisionTree<Key, double> fdt(separatorFactors, factorError); DecisionTree<Key, double> fdt(separatorFactors, factorProb);
auto discreteFactor = auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt); boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
@ -488,4 +492,65 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
return prob_tree; return prob_tree;
} }
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateHybridSequential() const {
Ordering continuous_ordering(this->continuousKeys()),
discrete_ordering(this->discreteKeys());
// Eliminate continuous
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
BaseEliminateable::eliminatePartialSequential(
continuous_ordering, EliminationTraitsType::DefaultEliminate);
// Get the last continuous conditional which will have all the discrete keys
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
// Get all the discrete assignments
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());
// Create a decision tree of all the different VectorValues
std::vector<VectorValues::shared_ptr> vector_values;
for (const DiscreteValues &assignment : assignments) {
VectorValues values = bayesNet->optimize(assignment);
vector_values.push_back(boost::make_shared<VectorValues>(values));
}
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values);
// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes;
for (const DiscreteValues &assignment : assignments) {
double error = 0.0;
VectorValues delta = *delta_tree(assignment);
for (auto factor : *this) {
if (factor->isHybrid()) {
auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
error += f->error(delta, assignment);
} else if (factor->isContinuous()) {
auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor);
error += f->inner()->error(delta);
}
}
probPrimes.push_back(exp(-error));
}
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
// Perform discrete elimination
HybridBayesNet::shared_ptr discreteBayesNet =
discreteGraph->eliminateSequential(
discrete_ordering, EliminationTraitsType::DefaultEliminate);
bayesNet->add(*discreteBayesNet);
return bayesNet;
}
} // namespace gtsam } // namespace gtsam

View File

@ -190,6 +190,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
AlgebraicDecisionTree<Key> probPrime( AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const; const VectorValues& continuousValues) const;
boost::shared_ptr<BayesNetType> eliminateHybridSequential() const;
/** /**
* @brief Return a Colamd constrained ordering where the discrete keys are * @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys. * eliminated after the continuous keys.