Merge pull request #2134 from borglab/wrap/discrete

Wrap DiscreteMarginals
release/4.3a0
Frank Dellaert 2025-05-15 12:26:04 -04:00 committed by GitHub
commit b14bae5ddb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 84 additions and 23 deletions

View File

@ -0,0 +1,61 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteMarginals.cpp
* @brief A class for computing marginals in a DiscreteFactorGraph
* @author Abhijit Kundu
* @author Richard Roberts
* @author Varun Agrawal
* @author Frank Dellaert
* @date June 4, 2012
*/
#include <gtsam/discrete/DiscreteMarginals.h>
namespace gtsam {
/* ************************************************************************* */
DiscreteMarginals::DiscreteMarginals(const DiscreteFactorGraph& graph) {
bayesTree_ = graph.eliminateMultifrontal();
}
/* ************************************************************************* */
DiscreteFactor::shared_ptr DiscreteMarginals::operator()(Key variable) const {
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor =
bayesTree_->marginalFactor(variable, &EliminateDiscrete);
return marginalFactor;
}
/* ************************************************************************* */
Vector DiscreteMarginals::marginalProbabilities(const DiscreteKey& key) const {
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first);
// Create result
Vector vResult(key.second);
for (size_t state = 0; state < key.second; ++state) {
DiscreteValues values;
values[key.first] = state;
vResult(state) = (*marginalFactor)(values);
}
return vResult;
}
/* ************************************************************************* */
void DiscreteMarginals::print(const std::string& s,
const KeyFormatter formatter) const {
std::cout << (s.empty() ? "Discrete Marginals of:" : s + " ") << std::endl;
bayesTree_->print("", formatter);
}
} /* namespace gtsam */

View File

@ -14,6 +14,7 @@
* @brief A class for computing marginals in a DiscreteFactorGraph * @brief A class for computing marginals in a DiscreteFactorGraph
* @author Abhijit Kundu * @author Abhijit Kundu
* @author Richard Roberts * @author Richard Roberts
* @author Varun Agrawal
* @author Frank Dellaert * @author Frank Dellaert
* @date June 4, 2012 * @date June 4, 2012
*/ */
@ -30,7 +31,7 @@ namespace gtsam {
* A class for computing marginals of variables in a DiscreteFactorGraph * A class for computing marginals of variables in a DiscreteFactorGraph
* @ingroup discrete * @ingroup discrete
*/ */
class DiscreteMarginals { class GTSAM_EXPORT DiscreteMarginals {
protected: protected:
DiscreteBayesTree::shared_ptr bayesTree_; DiscreteBayesTree::shared_ptr bayesTree_;
@ -38,38 +39,23 @@ class DiscreteMarginals {
DiscreteMarginals() {} DiscreteMarginals() {}
/** Construct a marginals class. /** Construct a marginals class.
* @param graph The factor graph defining the full joint * @param graph The factor graph defining the full joint
* distribution on all variables. * distribution on all variables.
*/ */
DiscreteMarginals(const DiscreteFactorGraph& graph) { DiscreteMarginals(const DiscreteFactorGraph& graph);
bayesTree_ = graph.eliminateMultifrontal();
}
/** Compute the marginal of a single variable */ /** Compute the marginal of a single variable */
DiscreteFactor::shared_ptr operator()(Key variable) const { DiscreteFactor::shared_ptr operator()(Key variable) const;
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor =
bayesTree_->marginalFactor(variable, &EliminateDiscrete);
return marginalFactor;
}
/** Compute the marginal of a single variable /** Compute the marginal of a single variable
* @param key DiscreteKey of the Variable * @param key DiscreteKey of the Variable
* @return Vector of marginal probabilities * @return Vector of marginal probabilities
*/ */
Vector marginalProbabilities(const DiscreteKey& key) const { Vector marginalProbabilities(const DiscreteKey& key) const;
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first);
// Create result /// Print details
Vector vResult(key.second); void print(const std::string& s = "",
for (size_t state = 0; state < key.second; ++state) { const KeyFormatter formatter = DefaultKeyFormatter) const;
DiscreteValues values;
values[key.first] = state;
vResult(state) = (*marginalFactor)(values);
}
return vResult;
}
}; };
} /* namespace gtsam */ } /* namespace gtsam */

View File

@ -494,4 +494,18 @@ class DiscreteSearch {
std::vector<gtsam::DiscreteSearchSolution> run(size_t K = 1) const; std::vector<gtsam::DiscreteSearchSolution> run(size_t K = 1) const;
}; };
#include <gtsam/discrete/DiscreteMarginals.h>
class DiscreteMarginals {
DiscreteMarginals();
DiscreteMarginals(const gtsam::DiscreteFactorGraph& graph);
gtsam::DiscreteFactor* operator()(gtsam::Key variable) const;
gtsam::Vector marginalProbabilities(const gtsam::DiscreteKey& key) const;
void print(const std::string& s = "",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
} // namespace gtsam } // namespace gtsam