format DiscreteMarginals.h

release/4.3a0
Varun Agrawal 2025-01-21 00:34:47 -05:00
parent dc79f492b2
commit 4b3c4093d5
1 changed files with 17 additions and 21 deletions

View File

@ -20,28 +20,26 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
namespace gtsam { namespace gtsam {
/** /**
* A class for computing marginals of variables in a DiscreteFactorGraph * A class for computing marginals of variables in a DiscreteFactorGraph
* @ingroup discrete * @ingroup discrete
*/ */
class DiscreteMarginals { class DiscreteMarginals {
protected:
DiscreteBayesTree::shared_ptr bayesTree_;
protected: public:
DiscreteBayesTree::shared_ptr bayesTree_;
public:
DiscreteMarginals() {} DiscreteMarginals() {}
/** Construct a marginals class. /** Construct a marginals class.
* @param graph The factor graph defining the full joint distribution on all variables. * @param graph The factor graph defining the full joint
* distribution on all variables.
*/ */
DiscreteMarginals(const DiscreteFactorGraph& graph) { DiscreteMarginals(const DiscreteFactorGraph& graph) {
bayesTree_ = graph.eliminateMultifrontal(); bayesTree_ = graph.eliminateMultifrontal();
@ -50,8 +48,8 @@ class DiscreteMarginals {
/** Compute the marginal of a single variable */ /** Compute the marginal of a single variable */
DiscreteFactor::shared_ptr operator()(Key variable) const { DiscreteFactor::shared_ptr operator()(Key variable) const {
// Compute marginal // Compute marginal
DiscreteFactor::shared_ptr marginalFactor; DiscreteFactor::shared_ptr marginalFactor =
marginalFactor = bayesTree_->marginalFactor(variable, &EliminateDiscrete); bayesTree_->marginalFactor(variable, &EliminateDiscrete);
return marginalFactor; return marginalFactor;
} }
@ -61,19 +59,17 @@ class DiscreteMarginals {
*/ */
Vector marginalProbabilities(const DiscreteKey& key) const { Vector marginalProbabilities(const DiscreteKey& key) const {
// Compute marginal // Compute marginal
DiscreteFactor::shared_ptr marginalFactor; DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first);
marginalFactor = bayesTree_->marginalFactor(key.first, &EliminateDiscrete);
//Create result // Create result
Vector vResult(key.second); Vector vResult(key.second);
for (size_t state = 0; state < key.second ; ++ state) { for (size_t state = 0; state < key.second; ++state) {
DiscreteValues values; DiscreteValues values;
values[key.first] = state; values[key.first] = state;
vResult(state) = (*marginalFactor)(values); vResult(state) = (*marginalFactor)(values);
} }
return vResult; return vResult;
} }
};
};
} /* namespace gtsam */ } /* namespace gtsam */