Merge pull request #1986 from borglab/improvements
commit
7dfdde30fd
|
@ -20,9 +20,9 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
|
||||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
|
||||||
#include <gtsam/base/Vector.h>
|
#include <gtsam/base/Vector.h>
|
||||||
|
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -31,17 +31,15 @@ namespace gtsam {
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
class DiscreteMarginals {
|
class DiscreteMarginals {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
||||||
DiscreteBayesTree::shared_ptr bayesTree_;
|
DiscreteBayesTree::shared_ptr bayesTree_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
DiscreteMarginals() {}
|
DiscreteMarginals() {}
|
||||||
|
|
||||||
/** Construct a marginals class.
|
/** Construct a marginals class.
|
||||||
* @param graph The factor graph defining the full joint distribution on all variables.
|
* @param graph The factor graph defining the full joint
|
||||||
|
* distribution on all variables.
|
||||||
*/
|
*/
|
||||||
DiscreteMarginals(const DiscreteFactorGraph& graph) {
|
DiscreteMarginals(const DiscreteFactorGraph& graph) {
|
||||||
bayesTree_ = graph.eliminateMultifrontal();
|
bayesTree_ = graph.eliminateMultifrontal();
|
||||||
|
@ -50,8 +48,8 @@ class DiscreteMarginals {
|
||||||
/** Compute the marginal of a single variable */
|
/** Compute the marginal of a single variable */
|
||||||
DiscreteFactor::shared_ptr operator()(Key variable) const {
|
DiscreteFactor::shared_ptr operator()(Key variable) const {
|
||||||
// Compute marginal
|
// Compute marginal
|
||||||
DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteFactor::shared_ptr marginalFactor =
|
||||||
marginalFactor = bayesTree_->marginalFactor(variable, &EliminateDiscrete);
|
bayesTree_->marginalFactor(variable, &EliminateDiscrete);
|
||||||
return marginalFactor;
|
return marginalFactor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,8 +59,7 @@ class DiscreteMarginals {
|
||||||
*/
|
*/
|
||||||
Vector marginalProbabilities(const DiscreteKey& key) const {
|
Vector marginalProbabilities(const DiscreteKey& key) const {
|
||||||
// Compute marginal
|
// Compute marginal
|
||||||
DiscreteFactor::shared_ptr marginalFactor;
|
DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first);
|
||||||
marginalFactor = bayesTree_->marginalFactor(key.first, &EliminateDiscrete);
|
|
||||||
|
|
||||||
// Create result
|
// Create result
|
||||||
Vector vResult(key.second);
|
Vector vResult(key.second);
|
||||||
|
@ -73,7 +70,6 @@ class DiscreteMarginals {
|
||||||
}
|
}
|
||||||
return vResult;
|
return vResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} /* namespace gtsam */
|
} /* namespace gtsam */
|
||||||
|
|
|
@ -47,15 +47,21 @@ bool DiscreteValues::equals(const DiscreteValues& x, double tol) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteValues& DiscreteValues::insert(const DiscreteValues& values) {
|
DiscreteValues& DiscreteValues::insert(
|
||||||
for (const auto& kv : values) {
|
const std::pair<Key, size_t>& assignment) {
|
||||||
if (count(kv.first)) {
|
if (count(assignment.first)) {
|
||||||
throw std::out_of_range(
|
throw std::out_of_range(
|
||||||
"Requested to insert a DiscreteValues into another DiscreteValues "
|
"Requested to insert a DiscreteValues into another DiscreteValues "
|
||||||
"that already contains one or more of its keys.");
|
"that already contains one or more of its keys.");
|
||||||
} else {
|
} else {
|
||||||
this->emplace(kv);
|
this->emplace(assignment);
|
||||||
}
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteValues& DiscreteValues::insert(const DiscreteValues& values) {
|
||||||
|
for (const auto& kv : values) {
|
||||||
|
this->insert(kv);
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,6 +69,12 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
|
||||||
return Base::insert(value);
|
return Base::insert(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Insert key-assignment pair.
|
||||||
|
* Throws an invalid_argument exception if
|
||||||
|
* any keys to be inserted are already used. */
|
||||||
|
DiscreteValues& insert(const std::pair<Key, size_t>& assignment);
|
||||||
|
|
||||||
/** Insert all values from \c values. Throws an invalid_argument exception if
|
/** Insert all values from \c values. Throws an invalid_argument exception if
|
||||||
* any keys to be inserted are already used. */
|
* any keys to be inserted are already used. */
|
||||||
DiscreteValues& insert(const DiscreteValues& values);
|
DiscreteValues& insert(const DiscreteValues& values);
|
||||||
|
|
|
@ -124,7 +124,7 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::optimize() const {
|
DiscreteValues HybridBayesNet::mpe() const {
|
||||||
// Collect all the discrete factors to compute MPE
|
// Collect all the discrete factors to compute MPE
|
||||||
DiscreteFactorGraph discrete_fg;
|
DiscreteFactorGraph discrete_fg;
|
||||||
|
|
||||||
|
@ -140,9 +140,13 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return discrete_fg.optimize();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridValues HybridBayesNet::optimize() const {
|
||||||
// Solve for the MPE
|
// Solve for the MPE
|
||||||
DiscreteValues mpe = discrete_fg.optimize();
|
DiscreteValues mpe = this->mpe();
|
||||||
|
|
||||||
// Given the MPE, compute the optimal continuous values.
|
// Given the MPE, compute the optimal continuous values.
|
||||||
return HybridValues(optimize(mpe), mpe);
|
return HybridValues(optimize(mpe), mpe);
|
||||||
|
|
|
@ -146,6 +146,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
return evaluate(values);
|
return evaluate(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the Most Probable Explanation (MPE)
|
||||||
|
* of the discrete variables.
|
||||||
|
*
|
||||||
|
* @return DiscreteValues
|
||||||
|
*/
|
||||||
|
DiscreteValues mpe() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Solve the HybridBayesNet by first computing the MPE of all the
|
* @brief Solve the HybridBayesNet by first computing the MPE of all the
|
||||||
* discrete variables and then optimizing the continuous variables based on
|
* discrete variables and then optimizing the continuous variables based on
|
||||||
|
|
|
@ -59,7 +59,7 @@ DiscreteValues HybridBayesTree::discreteMaxProduct(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesTree::optimize() const {
|
DiscreteValues HybridBayesTree::mpe() const {
|
||||||
DiscreteFactorGraph discrete_fg;
|
DiscreteFactorGraph discrete_fg;
|
||||||
DiscreteValues mpe;
|
DiscreteValues mpe;
|
||||||
|
|
||||||
|
@ -73,11 +73,16 @@ HybridValues HybridBayesTree::optimize() const {
|
||||||
discrete_fg.push_back(discrete);
|
discrete_fg.push_back(discrete);
|
||||||
mpe = discreteMaxProduct(discrete_fg);
|
mpe = discreteMaxProduct(discrete_fg);
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(
|
mpe = DiscreteValues();
|
||||||
"HybridBayesTree root is not discrete-only. Please check elimination "
|
|
||||||
"ordering or use continuous factor graph.");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return mpe;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridValues HybridBayesTree::optimize() const {
|
||||||
|
DiscreteValues mpe = this->mpe();
|
||||||
|
|
||||||
VectorValues values = optimize(mpe);
|
VectorValues values = optimize(mpe);
|
||||||
return HybridValues(values, mpe);
|
return HybridValues(values, mpe);
|
||||||
}
|
}
|
||||||
|
|
|
@ -105,6 +105,14 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
||||||
*/
|
*/
|
||||||
VectorValues optimize(const DiscreteValues& assignment) const;
|
VectorValues optimize(const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the Most Probable Explanation (MPE)
|
||||||
|
* of the discrete variables.
|
||||||
|
*
|
||||||
|
* @return DiscreteValues
|
||||||
|
*/
|
||||||
|
DiscreteValues mpe() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the underlying Bayes tree.
|
* @brief Prune the underlying Bayes tree.
|
||||||
*
|
*
|
||||||
|
|
Loading…
Reference in New Issue