From 8460452990965c36c87a16486ae11ccbdaec37d0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 20 Jan 2025 15:19:10 -0500 Subject: [PATCH 1/3] separate MPE method in Hybrid Bayes Net/Tree --- gtsam/hybrid/HybridBayesNet.cpp | 8 ++++++-- gtsam/hybrid/HybridBayesNet.h | 8 ++++++++ gtsam/hybrid/HybridBayesTree.cpp | 13 +++++++++---- gtsam/hybrid/HybridBayesTree.h | 8 ++++++++ 4 files changed, 31 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 8668bedd6..f83435df2 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -124,7 +124,7 @@ GaussianBayesNet HybridBayesNet::choose( } /* ************************************************************************* */ -HybridValues HybridBayesNet::optimize() const { +DiscreteValues HybridBayesNet::mpe() const { // Collect all the discrete factors to compute MPE DiscreteFactorGraph discrete_fg; @@ -140,9 +140,13 @@ HybridValues HybridBayesNet::optimize() const { } } } + return discrete_fg.optimize(); +} +/* ************************************************************************* */ +HybridValues HybridBayesNet::optimize() const { // Solve for the MPE - DiscreteValues mpe = discrete_fg.optimize(); + DiscreteValues mpe = this->mpe(); // Given the MPE, compute the optimal continuous values. return HybridValues(optimize(mpe), mpe); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 3e07c71ce..90e3a6814 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -146,6 +146,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { return evaluate(values); } + /** + * @brief Compute the Most Probable Explanation (MPE) + * of the discrete variables. + * + * @return DiscreteValues + */ + DiscreteValues mpe() const; + /** * @brief Solve the HybridBayesNet by first computing the MPE of all the * discrete variables and then optimizing the continuous variables based on diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 31d256d6f..c74930a8e 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -59,7 +59,7 @@ DiscreteValues HybridBayesTree::discreteMaxProduct( } /* ************************************************************************* */ -HybridValues HybridBayesTree::optimize() const { +DiscreteValues HybridBayesTree::mpe() const { DiscreteFactorGraph discrete_fg; DiscreteValues mpe; @@ -73,11 +73,16 @@ HybridValues HybridBayesTree::optimize() const { discrete_fg.push_back(discrete); mpe = discreteMaxProduct(discrete_fg); } else { - throw std::runtime_error( - "HybridBayesTree root is not discrete-only. Please check elimination " - "ordering or use continuous factor graph."); + mpe = DiscreteValues(); } + return mpe; +} + +/* ************************************************************************* */ +HybridValues HybridBayesTree::optimize() const { + DiscreteValues mpe = this->mpe(); + VectorValues values = optimize(mpe); return HybridValues(values, mpe); } diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index ec29f7b1e..d73b47f56 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -105,6 +105,14 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { */ VectorValues optimize(const DiscreteValues& assignment) const; + /** + * @brief Compute the Most Probable Explanation (MPE) + * of the discrete variables. + * + * @return DiscreteValues + */ + DiscreteValues mpe() const; + /** * @brief Prune the underlying Bayes tree. * From dc79f492b273978b0060f775f27628389544ef97 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 00:33:16 -0500 Subject: [PATCH 2/3] DiscreteValues::insert for single key-value pair --- gtsam/discrete/DiscreteValues.cpp | 20 +++++++++++++------- gtsam/discrete/DiscreteValues.h | 6 ++++++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/DiscreteValues.cpp b/gtsam/discrete/DiscreteValues.cpp index e21fb71db..416dfb888 100644 --- a/gtsam/discrete/DiscreteValues.cpp +++ b/gtsam/discrete/DiscreteValues.cpp @@ -46,16 +46,22 @@ bool DiscreteValues::equals(const DiscreteValues& x, double tol) const { return true; } +/* ************************************************************************ */ +DiscreteValues& DiscreteValues::insert( + const std::pair& assignment) { + if (count(assignment.first)) { + throw std::out_of_range( + "Requested to insert a DiscreteValues into another DiscreteValues " + "that already contains one or more of its keys."); + } else { + this->emplace(assignment); + } + return *this; +} /* ************************************************************************ */ DiscreteValues& DiscreteValues::insert(const DiscreteValues& values) { for (const auto& kv : values) { - if (count(kv.first)) { - throw std::out_of_range( - "Requested to insert a DiscreteValues into another DiscreteValues " - "that already contains one or more of its keys."); - } else { - this->emplace(kv); - } + this->insert(kv); } return *this; } diff --git a/gtsam/discrete/DiscreteValues.h b/gtsam/discrete/DiscreteValues.h index 9fdff014c..fa8a8a846 100644 --- a/gtsam/discrete/DiscreteValues.h +++ b/gtsam/discrete/DiscreteValues.h @@ -69,6 +69,12 @@ class GTSAM_EXPORT DiscreteValues : public Assignment { return Base::insert(value); } + /** + * Insert key-assignment pair. + * Throws an invalid_argument exception if + * any keys to be inserted are already used. */ + DiscreteValues& insert(const std::pair& assignment); + /** Insert all values from \c values. Throws an invalid_argument exception if * any keys to be inserted are already used. */ DiscreteValues& insert(const DiscreteValues& values); From 4b3c4093d5123b3fd0c6c43ba6600af0cd4a158a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 21 Jan 2025 00:34:47 -0500 Subject: [PATCH 3/3] format DiscreteMarginals.h --- gtsam/discrete/DiscreteMarginals.h | 38 +++++++++++++----------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index b97e60805..62c80e657 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -20,28 +20,26 @@ #pragma once -#include -#include #include +#include +#include namespace gtsam { - /** - * A class for computing marginals of variables in a DiscreteFactorGraph - * @ingroup discrete - */ +/** + * A class for computing marginals of variables in a DiscreteFactorGraph + * @ingroup discrete + */ class DiscreteMarginals { + protected: + DiscreteBayesTree::shared_ptr bayesTree_; - protected: - - DiscreteBayesTree::shared_ptr bayesTree_; - - public: - + public: DiscreteMarginals() {} /** Construct a marginals class. - * @param graph The factor graph defining the full joint distribution on all variables. + * @param graph The factor graph defining the full joint + * distribution on all variables. */ DiscreteMarginals(const DiscreteFactorGraph& graph) { bayesTree_ = graph.eliminateMultifrontal(); @@ -50,8 +48,8 @@ class DiscreteMarginals { /** Compute the marginal of a single variable */ DiscreteFactor::shared_ptr operator()(Key variable) const { // Compute marginal - DiscreteFactor::shared_ptr marginalFactor; - marginalFactor = bayesTree_->marginalFactor(variable, &EliminateDiscrete); + DiscreteFactor::shared_ptr marginalFactor = + bayesTree_->marginalFactor(variable, &EliminateDiscrete); return marginalFactor; } @@ -61,19 +59,17 @@ class DiscreteMarginals { */ Vector marginalProbabilities(const DiscreteKey& key) const { // Compute marginal - DiscreteFactor::shared_ptr marginalFactor; - marginalFactor = bayesTree_->marginalFactor(key.first, &EliminateDiscrete); + DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first); - //Create result + // Create result Vector vResult(key.second); - for (size_t state = 0; state < key.second ; ++ state) { + for (size_t state = 0; state < key.second; ++state) { DiscreteValues values; values[key.first] = state; vResult(state) = (*marginalFactor)(values); } return vResult; } - - }; +}; } /* namespace gtsam */