Merge pull request #1986 from borglab/improvements

release/4.3a0
Varun Agrawal 2025-01-22 11:17:33 -05:00 committed by GitHub
commit 7dfdde30fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 67 additions and 34 deletions

View File

@ -20,9 +20,9 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
namespace gtsam { namespace gtsam {
@ -31,17 +31,15 @@ namespace gtsam {
* @ingroup discrete * @ingroup discrete
*/ */
class DiscreteMarginals { class DiscreteMarginals {
protected: protected:
DiscreteBayesTree::shared_ptr bayesTree_; DiscreteBayesTree::shared_ptr bayesTree_;
public: public:
DiscreteMarginals() {} DiscreteMarginals() {}
/** Construct a marginals class. /** Construct a marginals class.
* @param graph The factor graph defining the full joint distribution on all variables. * @param graph The factor graph defining the full joint
* distribution on all variables.
*/ */
DiscreteMarginals(const DiscreteFactorGraph& graph) { DiscreteMarginals(const DiscreteFactorGraph& graph) {
bayesTree_ = graph.eliminateMultifrontal(); bayesTree_ = graph.eliminateMultifrontal();
@ -50,8 +48,8 @@ class DiscreteMarginals {
/** Compute the marginal of a single variable */ /** Compute the marginal of a single variable */
DiscreteFactor::shared_ptr operator()(Key variable) const { DiscreteFactor::shared_ptr operator()(Key variable) const {
// Compute marginal // Compute marginal
DiscreteFactor::shared_ptr marginalFactor; DiscreteFactor::shared_ptr marginalFactor =
marginalFactor = bayesTree_->marginalFactor(variable, &EliminateDiscrete); bayesTree_->marginalFactor(variable, &EliminateDiscrete);
return marginalFactor; return marginalFactor;
} }
@ -61,8 +59,7 @@ class DiscreteMarginals {
*/ */
Vector marginalProbabilities(const DiscreteKey& key) const { Vector marginalProbabilities(const DiscreteKey& key) const {
// Compute marginal // Compute marginal
DiscreteFactor::shared_ptr marginalFactor; DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first);
marginalFactor = bayesTree_->marginalFactor(key.first, &EliminateDiscrete);
// Create result // Create result
Vector vResult(key.second); Vector vResult(key.second);
@ -73,7 +70,6 @@ class DiscreteMarginals {
} }
return vResult; return vResult;
} }
}; };
} /* namespace gtsam */ } /* namespace gtsam */

View File

@ -47,15 +47,21 @@ bool DiscreteValues::equals(const DiscreteValues& x, double tol) const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteValues& DiscreteValues::insert(const DiscreteValues& values) { DiscreteValues& DiscreteValues::insert(
for (const auto& kv : values) { const std::pair<Key, size_t>& assignment) {
if (count(kv.first)) { if (count(assignment.first)) {
throw std::out_of_range( throw std::out_of_range(
"Requested to insert a DiscreteValues into another DiscreteValues " "Requested to insert a DiscreteValues into another DiscreteValues "
"that already contains one or more of its keys."); "that already contains one or more of its keys.");
} else { } else {
this->emplace(kv); this->emplace(assignment);
} }
return *this;
}
/* ************************************************************************ */
DiscreteValues& DiscreteValues::insert(const DiscreteValues& values) {
for (const auto& kv : values) {
this->insert(kv);
} }
return *this; return *this;
} }

View File

@ -69,6 +69,12 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
return Base::insert(value); return Base::insert(value);
} }
/**
* Insert key-assignment pair.
* Throws an invalid_argument exception if
* any keys to be inserted are already used. */
DiscreteValues& insert(const std::pair<Key, size_t>& assignment);
/** Insert all values from \c values. Throws an invalid_argument exception if /** Insert all values from \c values. Throws an invalid_argument exception if
* any keys to be inserted are already used. */ * any keys to be inserted are already used. */
DiscreteValues& insert(const DiscreteValues& values); DiscreteValues& insert(const DiscreteValues& values);

View File

@ -124,7 +124,7 @@ GaussianBayesNet HybridBayesNet::choose(
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const { DiscreteValues HybridBayesNet::mpe() const {
// Collect all the discrete factors to compute MPE // Collect all the discrete factors to compute MPE
DiscreteFactorGraph discrete_fg; DiscreteFactorGraph discrete_fg;
@ -140,9 +140,13 @@ HybridValues HybridBayesNet::optimize() const {
} }
} }
} }
return discrete_fg.optimize();
}
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE // Solve for the MPE
DiscreteValues mpe = discrete_fg.optimize(); DiscreteValues mpe = this->mpe();
// Given the MPE, compute the optimal continuous values. // Given the MPE, compute the optimal continuous values.
return HybridValues(optimize(mpe), mpe); return HybridValues(optimize(mpe), mpe);

View File

@ -146,6 +146,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
return evaluate(values); return evaluate(values);
} }
/**
* @brief Compute the Most Probable Explanation (MPE)
* of the discrete variables.
*
* @return DiscreteValues
*/
DiscreteValues mpe() const;
/** /**
* @brief Solve the HybridBayesNet by first computing the MPE of all the * @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on * discrete variables and then optimizing the continuous variables based on

View File

@ -59,7 +59,7 @@ DiscreteValues HybridBayesTree::discreteMaxProduct(
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const { DiscreteValues HybridBayesTree::mpe() const {
DiscreteFactorGraph discrete_fg; DiscreteFactorGraph discrete_fg;
DiscreteValues mpe; DiscreteValues mpe;
@ -73,11 +73,16 @@ HybridValues HybridBayesTree::optimize() const {
discrete_fg.push_back(discrete); discrete_fg.push_back(discrete);
mpe = discreteMaxProduct(discrete_fg); mpe = discreteMaxProduct(discrete_fg);
} else { } else {
throw std::runtime_error( mpe = DiscreteValues();
"HybridBayesTree root is not discrete-only. Please check elimination "
"ordering or use continuous factor graph.");
} }
return mpe;
}
/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
DiscreteValues mpe = this->mpe();
VectorValues values = optimize(mpe); VectorValues values = optimize(mpe);
return HybridValues(values, mpe); return HybridValues(values, mpe);
} }

View File

@ -105,6 +105,14 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
*/ */
VectorValues optimize(const DiscreteValues& assignment) const; VectorValues optimize(const DiscreteValues& assignment) const;
/**
* @brief Compute the Most Probable Explanation (MPE)
* of the discrete variables.
*
* @return DiscreteValues
*/
DiscreteValues mpe() const;
/** /**
* @brief Prune the underlying Bayes tree. * @brief Prune the underlying Bayes tree.
* *