diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index d0d2b8d15..86d74ca22 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -51,6 +51,8 @@ #include #include +// #define HYBRID_TIMING + namespace gtsam { template class EliminateableFactorGraph; @@ -256,13 +258,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // DiscreteFactor, with the error for each discrete choice. if (keysOfSeparator.empty()) { // TODO(Varun) Use the math from the iMHS_Math-1-indexed document - // TODO(Varun) The prob of a leaf should be computed from the full Bayes Net VectorValues empty_values; - auto factorError = [&](const GaussianFactor::shared_ptr &factor) { - if (!factor) return 0.0; // TODO(fan): does this make sense? - return exp(-factor->error(empty_values)); + auto factorProb = [&](const GaussianFactor::shared_ptr &factor) { + if (!factor) { + return 0.0; // If nullptr, return 0.0 probability + } else { + return 1.0; + } }; - DecisionTree fdt(separatorFactors, factorError); + DecisionTree fdt(separatorFactors, factorProb); auto discreteFactor = boost::make_shared(discreteSeparator, fdt); @@ -488,4 +492,65 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( return prob_tree; } +/* ************************************************************************ */ +boost::shared_ptr +HybridGaussianFactorGraph::eliminateHybridSequential() const { + Ordering continuous_ordering(this->continuousKeys()), + discrete_ordering(this->discreteKeys()); + + // Eliminate continuous + HybridBayesNet::shared_ptr bayesNet; + HybridGaussianFactorGraph::shared_ptr discreteGraph; + std::tie(bayesNet, discreteGraph) = + BaseEliminateable::eliminatePartialSequential( + continuous_ordering, EliminationTraitsType::DefaultEliminate); + + // Get the last continuous conditional which will have all the discrete keys + auto last_conditional = bayesNet->at(bayesNet->size() - 1); + // Get all the discrete assignments + DiscreteKeys discrete_keys = last_conditional->discreteKeys(); + const std::vector assignments = + DiscreteValues::CartesianProduct(discrete_keys); + + // Reverse discrete keys order for correct tree construction + std::reverse(discrete_keys.begin(), discrete_keys.end()); + + // Create a decision tree of all the different VectorValues + std::vector vector_values; + for (const DiscreteValues &assignment : assignments) { + VectorValues values = bayesNet->optimize(assignment); + vector_values.push_back(boost::make_shared(values)); + } + DecisionTree delta_tree(discrete_keys, + vector_values); + + // Get the probPrime tree with the correct leaf probabilities + std::vector probPrimes; + for (const DiscreteValues &assignment : assignments) { + double error = 0.0; + VectorValues delta = *delta_tree(assignment); + for (auto factor : *this) { + if (factor->isHybrid()) { + auto f = boost::static_pointer_cast(factor); + error += f->error(delta, assignment); + + } else if (factor->isContinuous()) { + auto f = boost::static_pointer_cast(factor); + error += f->inner()->error(delta); + } + } + probPrimes.push_back(exp(-error)); + } + AlgebraicDecisionTree probPrimeTree(discrete_keys, probPrimes); + discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + + // Perform discrete elimination + HybridBayesNet::shared_ptr discreteBayesNet = + discreteGraph->eliminateSequential( + discrete_ordering, EliminationTraitsType::DefaultEliminate); + bayesNet->add(*discreteBayesNet); + + return bayesNet; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index c7e9aa60d..21a5740db 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -190,6 +190,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph AlgebraicDecisionTree probPrime( const VectorValues& continuousValues) const; + boost::shared_ptr eliminateHybridSequential() const; + /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys.