From 1815433cbbde1052237427f774a086b1eabe8430 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 7 Nov 2022 18:29:49 -0500 Subject: [PATCH] add methods to perform correct elimination --- gtsam/hybrid/HybridBayesNet.cpp | 6 ++ gtsam/hybrid/HybridGaussianFactorGraph.cpp | 110 +++++++++++++++------ gtsam/hybrid/HybridGaussianFactorGraph.h | 39 ++++++++ 3 files changed, 123 insertions(+), 32 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f0d53c416..7338873bc 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -229,6 +229,12 @@ HybridValues HybridBayesNet::optimize() const { /* ************************************************************************* */ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { GaussianBayesNet gbn = this->choose(assignment); + + // Check if there exists a nullptr in the GaussianBayesNet + // If yes, return an empty VectorValues + if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) { + return VectorValues(); + } return gbn.optimize(); } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 86d74ca22..e018d1046 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -492,6 +492,75 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::probPrime( return prob_tree; } +/* ************************************************************************ */ +DecisionTree +HybridGaussianFactorGraph::continuousDelta( + const DiscreteKeys &discrete_keys, + const boost::shared_ptr &continuousBayesNet, + const std::vector &assignments) const { + // Create a decision tree of all the different VectorValues + std::vector vector_values; + for (const DiscreteValues &assignment : assignments) { + VectorValues values = continuousBayesNet->optimize(assignment); + vector_values.push_back(boost::make_shared(values)); + } + DecisionTree delta_tree(discrete_keys, + vector_values); + + return delta_tree; +} + +/* ************************************************************************ */ +AlgebraicDecisionTree HybridGaussianFactorGraph::continuousProbPrimes( + const DiscreteKeys &discrete_keys, + const boost::shared_ptr &continuousBayesNet, + const std::vector &assignments) const { + // Create a decision tree of all the different VectorValues + DecisionTree delta_tree = + this->continuousDelta(discrete_keys, continuousBayesNet, assignments); + + // Get the probPrime tree with the correct leaf probabilities + std::vector probPrimes; + for (const DiscreteValues &assignment : assignments) { + VectorValues delta = *delta_tree(assignment); + + // If VectorValues is empty, it means this is a pruned branch. + // Set thr probPrime to 0.0. + if (delta.size() == 0) { + probPrimes.push_back(0.0); + continue; + } + + double error = 0.0; + + for (size_t idx = 0; idx < size(); idx++) { + auto factor = factors_.at(idx); + + if (factor->isHybrid()) { + if (auto c = boost::dynamic_pointer_cast(factor)) { + error += c->asMixture()->error(delta, assignment); + } + if (auto f = + boost::dynamic_pointer_cast(factor)) { + error += f->error(delta, assignment); + } + + } else if (factor->isContinuous()) { + if (auto f = + boost::dynamic_pointer_cast(factor)) { + error += f->inner()->error(delta); + } + if (auto cg = boost::dynamic_pointer_cast(factor)) { + error += cg->asGaussian()->error(delta); + } + } + } + probPrimes.push_back(exp(-error)); + } + AlgebraicDecisionTree probPrimeTree(discrete_keys, probPrimes); + return probPrimeTree; +} + /* ************************************************************************ */ boost::shared_ptr HybridGaussianFactorGraph::eliminateHybridSequential() const { @@ -502,52 +571,29 @@ HybridGaussianFactorGraph::eliminateHybridSequential() const { HybridBayesNet::shared_ptr bayesNet; HybridGaussianFactorGraph::shared_ptr discreteGraph; std::tie(bayesNet, discreteGraph) = - BaseEliminateable::eliminatePartialSequential( - continuous_ordering, EliminationTraitsType::DefaultEliminate); + BaseEliminateable::eliminatePartialSequential(continuous_ordering); // Get the last continuous conditional which will have all the discrete keys auto last_conditional = bayesNet->at(bayesNet->size() - 1); - // Get all the discrete assignments DiscreteKeys discrete_keys = last_conditional->discreteKeys(); + const std::vector assignments = DiscreteValues::CartesianProduct(discrete_keys); + // Save a copy of the original discrete key ordering + DiscreteKeys orig_discrete_keys(discrete_keys); // Reverse discrete keys order for correct tree construction std::reverse(discrete_keys.begin(), discrete_keys.end()); - // Create a decision tree of all the different VectorValues - std::vector vector_values; - for (const DiscreteValues &assignment : assignments) { - VectorValues values = bayesNet->optimize(assignment); - vector_values.push_back(boost::make_shared(values)); - } - DecisionTree delta_tree(discrete_keys, - vector_values); + AlgebraicDecisionTree probPrimeTree = + continuousProbPrimes(discrete_keys, bayesNet, assignments); - // Get the probPrime tree with the correct leaf probabilities - std::vector probPrimes; - for (const DiscreteValues &assignment : assignments) { - double error = 0.0; - VectorValues delta = *delta_tree(assignment); - for (auto factor : *this) { - if (factor->isHybrid()) { - auto f = boost::static_pointer_cast(factor); - error += f->error(delta, assignment); - - } else if (factor->isContinuous()) { - auto f = boost::static_pointer_cast(factor); - error += f->inner()->error(delta); - } - } - probPrimes.push_back(exp(-error)); - } - AlgebraicDecisionTree probPrimeTree(discrete_keys, probPrimes); - discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + discreteGraph->add(DecisionTreeFactor(orig_discrete_keys, probPrimeTree)); // Perform discrete elimination HybridBayesNet::shared_ptr discreteBayesNet = - discreteGraph->eliminateSequential( - discrete_ordering, EliminationTraitsType::DefaultEliminate); + discreteGraph->eliminateSequential(discrete_ordering); + bayesNet->add(*discreteBayesNet); return bayesNet; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 21a5740db..8c387ec9b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -25,6 +25,7 @@ #include #include #include +#include namespace gtsam { @@ -190,6 +191,44 @@ class GTSAM_EXPORT HybridGaussianFactorGraph AlgebraicDecisionTree probPrime( const VectorValues& continuousValues) const; + /** + * @brief Compute the VectorValues solution for the continuous variables for + * each mode. + * + * @param discrete_keys The discrete keys which form all the modes. + * @param continuousBayesNet The Bayes Net representing the continuous + * eliminated variables. + * @param assignments List of all discrete assignments to create the final + * decision tree. + * @return DecisionTree + */ + DecisionTree continuousDelta( + const DiscreteKeys& discrete_keys, + const boost::shared_ptr& continuousBayesNet, + const std::vector& assignments) const; + + /** + * @brief Compute the unnormalized probabilities of the continuous variables + * for each of the modes. + * + * @param discrete_keys The discrete keys which form all the modes. + * @param continuousBayesNet The Bayes Net representing the continuous + * eliminated variables. + * @param assignments List of all discrete assignments to create the final + * decision tree. + * @return AlgebraicDecisionTree + */ + AlgebraicDecisionTree continuousProbPrimes( + const DiscreteKeys& discrete_keys, + const boost::shared_ptr& continuousBayesNet, + const std::vector& assignments) const; + + /** + * @brief Custom elimination function which computes the correct + * continuous probabilities. + * + * @return boost::shared_ptr + */ boost::shared_ptr eliminateHybridSequential() const; /**