diff --git a/gtsam/discrete/DiscreteMarginals.cpp b/gtsam/discrete/DiscreteMarginals.cpp new file mode 100644 index 000000000..60f9ba896 --- /dev/null +++ b/gtsam/discrete/DiscreteMarginals.cpp @@ -0,0 +1,61 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteMarginals.cpp + * @brief A class for computing marginals in a DiscreteFactorGraph + * @author Abhijit Kundu + * @author Richard Roberts + * @author Varun Agrawal + * @author Frank Dellaert + * @date June 4, 2012 + */ + +#include + +namespace gtsam { + +/* ************************************************************************* */ +DiscreteMarginals::DiscreteMarginals(const DiscreteFactorGraph& graph) { + bayesTree_ = graph.eliminateMultifrontal(); +} + +/* ************************************************************************* */ +DiscreteFactor::shared_ptr DiscreteMarginals::operator()(Key variable) const { + // Compute marginal + DiscreteFactor::shared_ptr marginalFactor = + bayesTree_->marginalFactor(variable, &EliminateDiscrete); + return marginalFactor; +} + +/* ************************************************************************* */ +Vector DiscreteMarginals::marginalProbabilities(const DiscreteKey& key) const { + // Compute marginal + DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first); + + // Create result + Vector vResult(key.second); + for (size_t state = 0; state < key.second; ++state) { + DiscreteValues values; + values[key.first] = state; + vResult(state) = (*marginalFactor)(values); + } + return vResult; +} + +/* ************************************************************************* */ +void DiscreteMarginals::print(const std::string& s, + const KeyFormatter formatter) const { + std::cout << (s.empty() ? "Discrete Marginals of:" : s + " ") << std::endl; + bayesTree_->print("", formatter); +} + +} /* namespace gtsam */ diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index 62c80e657..91082f491 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -14,6 +14,7 @@ * @brief A class for computing marginals in a DiscreteFactorGraph * @author Abhijit Kundu * @author Richard Roberts + * @author Varun Agrawal * @author Frank Dellaert * @date June 4, 2012 */ @@ -30,7 +31,7 @@ namespace gtsam { * A class for computing marginals of variables in a DiscreteFactorGraph * @ingroup discrete */ -class DiscreteMarginals { +class GTSAM_EXPORT DiscreteMarginals { protected: DiscreteBayesTree::shared_ptr bayesTree_; @@ -38,38 +39,23 @@ class DiscreteMarginals { DiscreteMarginals() {} /** Construct a marginals class. - * @param graph The factor graph defining the full joint + * @param graph The factor graph defining the full joint * distribution on all variables. */ - DiscreteMarginals(const DiscreteFactorGraph& graph) { - bayesTree_ = graph.eliminateMultifrontal(); - } + DiscreteMarginals(const DiscreteFactorGraph& graph); /** Compute the marginal of a single variable */ - DiscreteFactor::shared_ptr operator()(Key variable) const { - // Compute marginal - DiscreteFactor::shared_ptr marginalFactor = - bayesTree_->marginalFactor(variable, &EliminateDiscrete); - return marginalFactor; - } + DiscreteFactor::shared_ptr operator()(Key variable) const; /** Compute the marginal of a single variable * @param key DiscreteKey of the Variable * @return Vector of marginal probabilities */ - Vector marginalProbabilities(const DiscreteKey& key) const { - // Compute marginal - DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first); + Vector marginalProbabilities(const DiscreteKey& key) const; - // Create result - Vector vResult(key.second); - for (size_t state = 0; state < key.second; ++state) { - DiscreteValues values; - values[key.first] = state; - vResult(state) = (*marginalFactor)(values); - } - return vResult; - } + /// Print details + void print(const std::string& s = "", + const KeyFormatter formatter = DefaultKeyFormatter) const; }; } /* namespace gtsam */ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 2f15de4c6..27f1fdfa1 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -494,4 +494,18 @@ class DiscreteSearch { std::vector run(size_t K = 1) const; }; +#include + +class DiscreteMarginals { + DiscreteMarginals(); + DiscreteMarginals(const gtsam::DiscreteFactorGraph& graph); + + gtsam::DiscreteFactor* operator()(gtsam::Key variable) const; + gtsam::Vector marginalProbabilities(const gtsam::DiscreteKey& key) const; + + void print(const std::string& s = "", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + } // namespace gtsam