format DiscreteMarginals.h
parent
dc79f492b2
commit
4b3c4093d5
|
@ -20,28 +20,26 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
|
||||||
#include <gtsam/base/Vector.h>
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A class for computing marginals of variables in a DiscreteFactorGraph
|
* A class for computing marginals of variables in a DiscreteFactorGraph
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
class DiscreteMarginals {
|
class DiscreteMarginals {
|
||||||
|
protected:
|
||||||
|
DiscreteBayesTree::shared_ptr bayesTree_;
|
||||||
|
|
||||||
protected:
|
public:
|
||||||
|
|
||||||
DiscreteBayesTree::shared_ptr bayesTree_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
DiscreteMarginals() {}
|
DiscreteMarginals() {}
|
||||||
|
|
||||||
/** Construct a marginals class.
|
/** Construct a marginals class.
|
||||||
* @param graph The factor graph defining the full joint distribution on all variables.
|
* @param graph The factor graph defining the full joint
|
||||||
|
* distribution on all variables.
|
||||||
*/
|
*/
|
||||||
DiscreteMarginals(const DiscreteFactorGraph& graph) {
|
DiscreteMarginals(const DiscreteFactorGraph& graph) {
|
||||||
bayesTree_ = graph.eliminateMultifrontal();
|
bayesTree_ = graph.eliminateMultifrontal();
|
||||||
|
@ -50,8 +48,8 @@ class DiscreteMarginals {
|
||||||
/** Compute the marginal of a single variable */
|
/** Compute the marginal of a single variable */
|
||||||
DiscreteFactor::shared_ptr operator()(Key variable) const {
|
DiscreteFactor::shared_ptr operator()(Key variable) const {
|
||||||
// Compute marginal
|
// Compute marginal
|
||||||
DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteFactor::shared_ptr marginalFactor =
|
||||||
marginalFactor = bayesTree_->marginalFactor(variable, &EliminateDiscrete);
|
bayesTree_->marginalFactor(variable, &EliminateDiscrete);
|
||||||
return marginalFactor;
|
return marginalFactor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,19 +59,17 @@ class DiscreteMarginals {
|
||||||
*/
|
*/
|
||||||
Vector marginalProbabilities(const DiscreteKey& key) const {
|
Vector marginalProbabilities(const DiscreteKey& key) const {
|
||||||
// Compute marginal
|
// Compute marginal
|
||||||
DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first);
|
||||||
marginalFactor = bayesTree_->marginalFactor(key.first, &EliminateDiscrete);
|
|
||||||
|
|
||||||
//Create result
|
// Create result
|
||||||
Vector vResult(key.second);
|
Vector vResult(key.second);
|
||||||
for (size_t state = 0; state < key.second ; ++ state) {
|
for (size_t state = 0; state < key.second; ++state) {
|
||||||
DiscreteValues values;
|
DiscreteValues values;
|
||||||
values[key.first] = state;
|
values[key.first] = state;
|
||||||
vResult(state) = (*marginalFactor)(values);
|
vResult(state) = (*marginalFactor)(values);
|
||||||
}
|
}
|
||||||
return vResult;
|
return vResult;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
};
|
|
||||||
|
|
||||||
} /* namespace gtsam */
|
} /* namespace gtsam */
|
||||||
|
|
Loading…
Reference in New Issue