diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index b97e60805..62c80e657 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -20,28 +20,26 @@ #pragma once -#include -#include #include +#include +#include namespace gtsam { - /** - * A class for computing marginals of variables in a DiscreteFactorGraph - * @ingroup discrete - */ +/** + * A class for computing marginals of variables in a DiscreteFactorGraph + * @ingroup discrete + */ class DiscreteMarginals { + protected: + DiscreteBayesTree::shared_ptr bayesTree_; - protected: - - DiscreteBayesTree::shared_ptr bayesTree_; - - public: - + public: DiscreteMarginals() {} /** Construct a marginals class. - * @param graph The factor graph defining the full joint distribution on all variables. + * @param graph The factor graph defining the full joint + * distribution on all variables. */ DiscreteMarginals(const DiscreteFactorGraph& graph) { bayesTree_ = graph.eliminateMultifrontal(); @@ -50,8 +48,8 @@ class DiscreteMarginals { /** Compute the marginal of a single variable */ DiscreteFactor::shared_ptr operator()(Key variable) const { // Compute marginal - DiscreteFactor::shared_ptr marginalFactor; - marginalFactor = bayesTree_->marginalFactor(variable, &EliminateDiscrete); + DiscreteFactor::shared_ptr marginalFactor = + bayesTree_->marginalFactor(variable, &EliminateDiscrete); return marginalFactor; } @@ -61,19 +59,17 @@ class DiscreteMarginals { */ Vector marginalProbabilities(const DiscreteKey& key) const { // Compute marginal - DiscreteFactor::shared_ptr marginalFactor; - marginalFactor = bayesTree_->marginalFactor(key.first, &EliminateDiscrete); + DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first); - //Create result + // Create result Vector vResult(key.second); - for (size_t state = 0; state < key.second ; ++ state) { + for (size_t state = 0; state < key.second; ++state) { DiscreteValues values; values[key.first] = state; vResult(state) = (*marginalFactor)(values); } return vResult; } - - }; +}; } /* namespace gtsam */