revive discrete, except testDiscreteBayesNet::sample, and the whole testDiscreteBayesTree, currently disabled. It's still a mess!!!
parent
1a8a670870
commit
9aede467d2
|
@ -6,7 +6,7 @@ set (gtsam_subdirs
|
|||
geometry
|
||||
inference
|
||||
symbolic
|
||||
#discrete
|
||||
discrete
|
||||
linear
|
||||
nonlinear
|
||||
slam
|
||||
|
@ -71,7 +71,7 @@ set(gtsam_srcs
|
|||
${geometry_srcs}
|
||||
${inference_srcs}
|
||||
${symbolic_srcs}
|
||||
#${discrete_srcs}
|
||||
${discrete_srcs}
|
||||
${linear_srcs}
|
||||
${nonlinear_srcs}
|
||||
${slam_srcs}
|
||||
|
|
|
@ -643,6 +643,7 @@ namespace gtsam {
|
|||
// where there is no more branch on "label": only the subtree under that
|
||||
// branch point corresponding to the value "index" is left instead.
|
||||
// The function below get all these smaller trees and "ops" them together.
|
||||
// This implements marginalization in Darwiche09book, pg 330
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::combine(const L& label,
|
||||
size_t cardinality, const Binary& op) const {
|
||||
|
|
|
@ -177,7 +177,8 @@ namespace gtsam {
|
|||
/** apply binary operation "op" to f and g */
|
||||
DecisionTree apply(const DecisionTree& g, const Binary& op) const;
|
||||
|
||||
/** create a new function where value(label)==index */
|
||||
/** create a new function where value(label)==index
|
||||
* It's like "restrict" in Darwiche09book pg329, 330? */
|
||||
DecisionTree choose(const L& label, size_t index) const {
|
||||
NodePtr newRoot = root_->choose(label, index);
|
||||
return DecisionTree(newRoot);
|
||||
|
|
|
@ -44,21 +44,25 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DecisionTreeFactor::equals(const This& other, double tol) const {
|
||||
return IndexFactorOrdered::equals(other, tol) && Potentials::equals(other, tol);
|
||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
|
||||
if(!dynamic_cast<const DecisionTreeFactor*>(&other))
|
||||
return false;
|
||||
else {
|
||||
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||
return Potentials::equals(f, tol);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DecisionTreeFactor::print(const string& s,
|
||||
const IndexFormatter& formatter) const {
|
||||
cout << s;
|
||||
IndexFactorOrdered::print("IndexFactor:",formatter);
|
||||
Potentials::print("Potentials:",formatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor DecisionTreeFactor::apply //
|
||||
(const DecisionTreeFactor& f, ADT::Binary op) const {
|
||||
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
||||
ADT::Binary op) const {
|
||||
map<Index,size_t> cs; // new cardinalities
|
||||
// make unique key-cardinality map
|
||||
BOOST_FOREACH(Index j, keys()) cs[j] = cardinality(j);
|
||||
|
@ -74,8 +78,8 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine //
|
||||
(size_t nrFrontals, ADT::Binary op) const {
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
|
||||
ADT::Binary op) const {
|
||||
|
||||
if (nrFrontals > size()) throw invalid_argument(
|
||||
(boost::format(
|
||||
|
@ -99,5 +103,36 @@ namespace gtsam {
|
|||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||
}
|
||||
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
|
||||
ADT::Binary op) const {
|
||||
|
||||
if (frontalKeys.size() > size()) throw invalid_argument(
|
||||
(boost::format(
|
||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
||||
% frontalKeys.size() % size()).str());
|
||||
|
||||
// sum over nrFrontals keys
|
||||
size_t i;
|
||||
ADT result(*this);
|
||||
for (i = 0; i < frontalKeys.size(); i++) {
|
||||
Index j = frontalKeys[i];
|
||||
result = result.combine(j, cardinality(j), op);
|
||||
}
|
||||
|
||||
// create new factor, note we collect keys that are not in frontalKeys
|
||||
// TODO: why do we need this??? result should contain correct keys!!!
|
||||
DiscreteKeys dkeys;
|
||||
for (i = 0; i < keys().size(); i++) {
|
||||
Index j = keys()[i];
|
||||
// TODO: inefficient!
|
||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
|
||||
continue;
|
||||
dkeys.push_back(DiscreteKey(j,cardinality(j)));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/discrete/Potentials.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
|
||||
#include <boost/foreach.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
@ -41,7 +42,7 @@ namespace gtsam {
|
|||
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef DecisionTreeFactor This;
|
||||
typedef DiscreteConditional ConditionalType;
|
||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
|
||||
|
||||
public:
|
||||
|
@ -69,7 +70,7 @@ namespace gtsam {
|
|||
/// @{
|
||||
|
||||
/// equality
|
||||
bool equals(const DecisionTreeFactor& other, double tol = 1e-9) const;
|
||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
|
||||
|
||||
// print
|
||||
virtual void print(const std::string& s = "DecisionTreeFactor:\n",
|
||||
|
@ -104,6 +105,11 @@ namespace gtsam {
|
|||
return combine(nrFrontals, ADT::Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
shared_ptr sum(const Ordering& keys) const {
|
||||
return combine(keys, ADT::Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator values
|
||||
shared_ptr max(size_t nrFrontals) const {
|
||||
return combine(nrFrontals, ADT::Ring::max);
|
||||
|
@ -129,27 +135,39 @@ namespace gtsam {
|
|||
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
|
||||
|
||||
/**
|
||||
* @brief Permutes the keys in Potentials and DiscreteFactor
|
||||
*
|
||||
* This re-implements the permuteWithInverse() in both Potentials
|
||||
* and DiscreteFactor by doing both of them together.
|
||||
* Combine frontal variables in an Ordering using binary operator "op"
|
||||
* @param nrFrontals nr. of frontal to combine variables in this factor
|
||||
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
|
||||
* @return shared pointer to newly created DecisionTreeFactor
|
||||
*/
|
||||
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
||||
|
||||
void permuteWithInverse(const Permutation& inversePermutation){
|
||||
DiscreteFactor::permuteWithInverse(inversePermutation);
|
||||
Potentials::permuteWithInverse(inversePermutation);
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply a reduction, which is a remapping of variable indices.
|
||||
*/
|
||||
virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
|
||||
DiscreteFactor::reduceWithInverse(inverseReduction);
|
||||
Potentials::reduceWithInverse(inverseReduction);
|
||||
}
|
||||
/** Test whether the factor is empty */
|
||||
virtual bool empty() const { return size() == 0; }
|
||||
|
||||
// /**
|
||||
// * @brief Permutes the keys in Potentials and DiscreteFactor
|
||||
// *
|
||||
// * This re-implements the permuteWithInverse() in both Potentials
|
||||
// * and DiscreteFactor by doing both of them together.
|
||||
// */
|
||||
//
|
||||
// void permuteWithInverse(const Permutation& inversePermutation){
|
||||
// DiscreteFactor::permuteWithInverse(inversePermutation);
|
||||
// Potentials::permuteWithInverse(inversePermutation);
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * Apply a reduction, which is a remapping of variable indices.
|
||||
// */
|
||||
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
|
||||
// DiscreteFactor::reduceWithInverse(inverseReduction);
|
||||
// Potentials::reduceWithInverse(inverseReduction);
|
||||
// }
|
||||
|
||||
/// @}
|
||||
};
|
||||
};
|
||||
// DecisionTreeFactor
|
||||
|
||||
}// namespace gtsam
|
||||
|
|
|
@ -18,47 +18,53 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/inference/BayesNetOrdered-inl.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Explicitly instantiate so we don't have to include everywhere
|
||||
template class BayesNetOrdered<DiscreteConditional> ;
|
||||
// Instantiate base class
|
||||
template class FactorGraph<DiscreteConditional>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
void add_front(DiscreteBayesNet& bayesNet, const Signature& s) {
|
||||
bayesNet.push_front(boost::make_shared<DiscreteConditional>(s));
|
||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const
|
||||
{
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void add(DiscreteBayesNet& bayesNet, const Signature& s) {
|
||||
bayesNet.push_back(boost::make_shared<DiscreteConditional>(s));
|
||||
// void DiscreteBayesNet::add_front(const Signature& s) {
|
||||
// push_front(boost::make_shared<DiscreteConditional>(s));
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DiscreteBayesNet::add(const Signature& s) {
|
||||
push_back(boost::make_shared<DiscreteConditional>(s));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double evaluate(const DiscreteBayesNet& bn, const DiscreteConditional::Values & values) {
|
||||
double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const {
|
||||
// evaluate all conditionals and multiply
|
||||
double result = 1.0;
|
||||
BOOST_FOREACH(DiscreteConditional::shared_ptr conditional, bn)
|
||||
BOOST_FOREACH(DiscreteConditional::shared_ptr conditional, *this)
|
||||
result *= (*conditional)(values);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn) {
|
||||
DiscreteFactor::sharedValues DiscreteBayesNet::optimize() const {
|
||||
// solve each node in turn in topological sort order (parents first)
|
||||
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
|
||||
BOOST_REVERSE_FOREACH (DiscreteConditional::shared_ptr conditional, bn)
|
||||
BOOST_REVERSE_FOREACH (DiscreteConditional::shared_ptr conditional, *this)
|
||||
conditional->solveInPlace(*result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn) {
|
||||
DiscreteFactor::sharedValues DiscreteBayesNet::sample() const {
|
||||
// sample each node in turn in topological sort order (parents first)
|
||||
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
|
||||
BOOST_REVERSE_FOREACH(DiscreteConditional::shared_ptr conditional, bn)
|
||||
BOOST_REVERSE_FOREACH(DiscreteConditional::shared_ptr conditional, *this)
|
||||
conditional->sampleInPlace(*result);
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -20,27 +20,80 @@
|
|||
#include <vector>
|
||||
#include <map>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <gtsam/inference/BayesNetOrdered.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
typedef BayesNetOrdered<DiscreteConditional> DiscreteBayesNet;
|
||||
/** A Bayes net made from linear-Discrete densities */
|
||||
class GTSAM_EXPORT DiscreteBayesNet: public FactorGraph<DiscreteConditional>
|
||||
{
|
||||
public:
|
||||
|
||||
typedef FactorGraph<DiscreteConditional> Base;
|
||||
typedef DiscreteBayesNet This;
|
||||
typedef DiscreteConditional ConditionalType;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
typedef boost::shared_ptr<ConditionalType> sharedConditional;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Construct empty factor graph */
|
||||
DiscreteBayesNet() {}
|
||||
|
||||
/** Construct from iterator over conditionals */
|
||||
template<typename ITERATOR>
|
||||
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
|
||||
|
||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||
template<class CONTAINER>
|
||||
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
|
||||
|
||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
||||
template<class DERIVEDCONDITIONAL>
|
||||
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This& bn, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Add a DiscreteCondtional */
|
||||
GTSAM_EXPORT void add(DiscreteBayesNet&, const Signature& s);
|
||||
GTSAM_EXPORT void add(const Signature& s);
|
||||
|
||||
/** Add a DiscreteCondtional in front, when listing parents first*/
|
||||
GTSAM_EXPORT void add_front(DiscreteBayesNet&, const Signature& s);
|
||||
// /** Add a DiscreteCondtional in front, when listing parents first*/
|
||||
// GTSAM_EXPORT void add_front(const Signature& s);
|
||||
|
||||
//** evaluate for given Values */
|
||||
GTSAM_EXPORT double evaluate(const DiscreteBayesNet& bn, const DiscreteConditional::Values & values);
|
||||
GTSAM_EXPORT double evaluate(const DiscreteConditional::Values & values) const;
|
||||
|
||||
/** Optimize function for back-substitution. */
|
||||
GTSAM_EXPORT DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn);
|
||||
/**
|
||||
* Solve the DiscreteBayesNet by back-substitution
|
||||
*/
|
||||
DiscreteFactor::sharedValues optimize() const;
|
||||
|
||||
/** Do ancestral sampling */
|
||||
GTSAM_EXPORT DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn);
|
||||
GTSAM_EXPORT DiscreteFactor::sharedValues sample() const;
|
||||
|
||||
///@}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template<class ARCHIVE>
|
||||
void serialize(ARCHIVE & ar, const unsigned int version) {
|
||||
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteBayesTree.cpp
|
||||
* @brief Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree
|
||||
* @brief DiscreteBayesTree
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#include <gtsam/base/treeTraversal-inst.h>
|
||||
#include <gtsam/inference/BayesTree-inst.h>
|
||||
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>;
|
||||
template class BayesTree<DiscreteBayesTreeClique>;
|
||||
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DiscreteBayesTree::equals(const This& other, double tol) const
|
||||
{
|
||||
return Base::equals(other, tol);
|
||||
}
|
||||
|
||||
} // \namespace gtsam
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteBayesTree.h
|
||||
* @brief Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree
|
||||
* @brief DiscreteBayesTree
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/inference/BayesTree.h>
|
||||
#include <gtsam/inference/BayesTreeCliqueBase.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Forward declarations
|
||||
class DiscreteConditional;
|
||||
class VectorValues;
|
||||
|
||||
/* ************************************************************************* */
|
||||
/** A clique in a DiscreteBayesTree */
|
||||
class GTSAM_EXPORT DiscreteBayesTreeClique :
|
||||
public BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>
|
||||
{
|
||||
public:
|
||||
typedef DiscreteBayesTreeClique This;
|
||||
typedef BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> Base;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
typedef boost::weak_ptr<This> weak_ptr;
|
||||
DiscreteBayesTreeClique() {}
|
||||
DiscreteBayesTreeClique(const boost::shared_ptr<DiscreteConditional>& conditional) : Base(conditional) {}
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
/** A Bayes tree representing a Discrete density */
|
||||
class GTSAM_EXPORT DiscreteBayesTree :
|
||||
public BayesTree<DiscreteBayesTreeClique>
|
||||
{
|
||||
private:
|
||||
typedef BayesTree<DiscreteBayesTreeClique> Base;
|
||||
|
||||
public:
|
||||
typedef DiscreteBayesTree This;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
|
||||
/** Default constructor, creates an empty Bayes tree */
|
||||
DiscreteBayesTree() {}
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This& other, double tol = 1e-9) const;
|
||||
};
|
||||
|
||||
}
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
|
||||
|
@ -34,61 +35,72 @@ using namespace std;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||
const DecisionTreeFactor& f) :
|
||||
IndexConditionalOrdered(f.keys(), nrFrontals), Potentials(
|
||||
f / (*f.sum(nrFrontals))) {
|
||||
}
|
||||
// Instantiate base class
|
||||
template class Conditional<DecisionTreeFactor, DiscreteConditional> ;
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal) :
|
||||
IndexConditionalOrdered(joint.keys(), joint.size() - marginal.size()), Potentials(
|
||||
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) {
|
||||
/* ******************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||
const DecisionTreeFactor& f) :
|
||||
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal, const boost::optional<Ordering>& orderedKeys) :
|
||||
BaseFactor(
|
||||
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
|
||||
joint.size()-marginal.size()) {
|
||||
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
|
||||
cout << (firstFrontalKey()) << endl; //TODO Print all keys
|
||||
if (orderedKeys) {
|
||||
keys_.clear();
|
||||
keys_.insert(keys_.end(), orderedKeys->begin(), orderedKeys->end());
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const Signature& signature) :
|
||||
IndexConditionalOrdered(signature.indices(), 1), Potentials(
|
||||
signature.discreteKeysParentsFirst(), signature.cpt()) {
|
||||
}
|
||||
/* ******************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const Signature& signature) :
|
||||
BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional(
|
||||
1) {
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::print(const std::string& s, const IndexFormatter& formatter) const {
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::print(const std::string& s,
|
||||
const IndexFormatter& formatter) const {
|
||||
std::cout << s << std::endl;
|
||||
IndexConditionalOrdered::print(s, formatter);
|
||||
Potentials::print(s);
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
bool DiscreteConditional::equals(const DiscreteConditional& other, double tol) const {
|
||||
return IndexConditionalOrdered::equals(other, tol)
|
||||
&& Potentials::equals(other, tol);
|
||||
}
|
||||
/* ******************************************************************************** */
|
||||
bool DiscreteConditional::equals(const DiscreteConditional& other,
|
||||
double tol) const {
|
||||
return Potentials::equals(other, tol);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
Potentials::ADT DiscreteConditional::choose(
|
||||
const Values& parentsValues) const {
|
||||
/* ******************************************************************************** */
|
||||
Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const {
|
||||
ADT pFS(*this);
|
||||
Index j; size_t value;
|
||||
BOOST_FOREACH(Index key, parents())
|
||||
try {
|
||||
Index j = (key);
|
||||
size_t value = parentsValues.at(j);
|
||||
j = (key);
|
||||
value = parentsValues.at(j);
|
||||
pFS = pFS.choose(j, value);
|
||||
} catch (exception&) {
|
||||
throw runtime_error(
|
||||
"DiscreteConditional::choose: parent value missing");
|
||||
cout << "Key: " << j << " Value: " << value << endl;
|
||||
pFS.print("pFS: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
return pFS;
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::solveInPlace(Values& values) const {
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::solveInPlace(Values& values) const {
|
||||
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
|
||||
|
||||
print("Me: ");
|
||||
values.print("values:");
|
||||
cout << "nrFrontals/nrParents: " << nrFrontals() << "/" << nrParents() << endl;
|
||||
cout << "first frontal: " << firstFrontalKey() << endl;
|
||||
ADT pFS = choose(values); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
|
@ -97,7 +109,7 @@ namespace gtsam {
|
|||
|
||||
DiscreteKeys keys;
|
||||
BOOST_FOREACH(Index idx, frontals()) {
|
||||
DiscreteKey dk(idx,cardinality(idx));
|
||||
DiscreteKey dk(idx, cardinality(idx));
|
||||
keys & dk;
|
||||
}
|
||||
// Get all Possible Configurations
|
||||
|
@ -117,18 +129,18 @@ namespace gtsam {
|
|||
BOOST_FOREACH(Index j, frontals()) {
|
||||
values[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(Values& values) const {
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(Values& values) const {
|
||||
assert(nrFrontals() == 1);
|
||||
Index j = (firstFrontalKey());
|
||||
size_t sampled = sample(values); // Sample variable
|
||||
values[j] = sampled; // store result in partial solution
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
size_t DiscreteConditional::solve(const Values& parentsValues) const {
|
||||
/* ******************************************************************************** */
|
||||
size_t DiscreteConditional::solve(const Values& parentsValues) const {
|
||||
|
||||
// TODO: is this really the fastest way? I think it is.
|
||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||
|
@ -150,10 +162,10 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
return mpe;
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
||||
/* ******************************************************************************** */
|
||||
size_t DiscreteConditional::sample(const Values& parentsValues) const {
|
||||
|
||||
using boost::uniform_real;
|
||||
static boost::mt19937 gen(2); // random number generator
|
||||
|
@ -162,7 +174,8 @@ namespace gtsam {
|
|||
|
||||
// Get the correct conditional density
|
||||
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
|
||||
if (debug) GTSAM_PRINT(pFS);
|
||||
if (debug)
|
||||
GTSAM_PRINT(pFS);
|
||||
|
||||
// get cumulative distribution function (cdf)
|
||||
// TODO, only works for one key now, seems horribly slow this way
|
||||
|
@ -176,9 +189,11 @@ namespace gtsam {
|
|||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
sum += pValueS; // accumulate
|
||||
if (debug) cout << sum << " ";
|
||||
if (debug)
|
||||
cout << sum << " ";
|
||||
if (pValueS == 1) {
|
||||
if (debug) cout << "--> " << value << endl;
|
||||
if (debug)
|
||||
cout << "--> " << value << endl;
|
||||
return value; // shortcut exit
|
||||
}
|
||||
cdf[value] = sum;
|
||||
|
@ -188,20 +203,20 @@ namespace gtsam {
|
|||
uniform_real<> dist(0, cdf.back());
|
||||
boost::variate_generator<boost::mt19937&, uniform_real<> > die(gen, dist);
|
||||
size_t sampled = lower_bound(cdf.begin(), cdf.end(), die()) - cdf.begin();
|
||||
if (debug) cout << "-> " << sampled << endl;
|
||||
if (debug)
|
||||
cout << "-> " << sampled << endl;
|
||||
|
||||
return sampled;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::permuteWithInverse(const Permutation& inversePermutation){
|
||||
IndexConditionalOrdered::permuteWithInverse(inversePermutation);
|
||||
Potentials::permuteWithInverse(inversePermutation);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
//void DiscreteConditional::permuteWithInverse(
|
||||
// const Permutation& inversePermutation) {
|
||||
// IndexConditionalOrdered::permuteWithInverse(inversePermutation);
|
||||
// Potentials::permuteWithInverse(inversePermutation);
|
||||
//}
|
||||
/* ******************************************************************************** */
|
||||
|
||||
} // namespace
|
||||
}// namespace
|
||||
|
|
|
@ -20,25 +20,28 @@
|
|||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/IndexConditionalOrdered.h>
|
||||
#include <gtsam/inference/Conditional.h>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
/**
|
||||
* Discrete Conditional Density
|
||||
* Derives from DecisionTreeFactor
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteConditional: public IndexConditionalOrdered, public Potentials {
|
||||
class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
|
||||
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
||||
|
||||
public:
|
||||
public:
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef DiscreteFactor FactorType;
|
||||
typedef boost::shared_ptr<DiscreteConditional> shared_ptr;
|
||||
typedef IndexConditionalOrdered Base;
|
||||
typedef DiscreteConditional This; ///< Typedef to this class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
||||
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
|
||||
|
||||
/** A map from keys to values */
|
||||
/** A map from keys to values..
|
||||
* TODO: Again, do we need this??? */
|
||||
typedef Assignment<Index> Values;
|
||||
typedef boost::shared_ptr<Values> sharedValues;
|
||||
|
||||
|
@ -57,7 +60,7 @@ namespace gtsam {
|
|||
|
||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal);
|
||||
const DecisionTreeFactor& marginal, const boost::optional<Ordering>& orderedKeys = boost::none);
|
||||
|
||||
/**
|
||||
* Combine several conditional into a single one.
|
||||
|
@ -67,7 +70,8 @@ namespace gtsam {
|
|||
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
|
||||
* */
|
||||
template<typename ITERATOR>
|
||||
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
|
||||
static shared_ptr Combine(ITERATOR firstConditional,
|
||||
ITERATOR lastConditional);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
|
@ -121,33 +125,29 @@ namespace gtsam {
|
|||
/// sample in place, stores result in partial solution
|
||||
void sampleInPlace(Values& parentsValues) const;
|
||||
|
||||
/**
|
||||
* Permutes both IndexConditional and Potentials.
|
||||
*/
|
||||
void permuteWithInverse(const Permutation& inversePermutation);
|
||||
|
||||
/// @}
|
||||
|
||||
};
|
||||
};
|
||||
// DiscreteConditional
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<typename ITERATOR>
|
||||
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
||||
/* ************************************************************************* */
|
||||
template<typename ITERATOR>
|
||||
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
||||
ITERATOR firstConditional, ITERATOR lastConditional) {
|
||||
// TODO: check for being a clique
|
||||
|
||||
// multiply all the potentials of the given conditionals
|
||||
size_t nrFrontals = 0;
|
||||
DecisionTreeFactor product;
|
||||
for(ITERATOR it = firstConditional; it != lastConditional; ++it, ++nrFrontals) {
|
||||
for (ITERATOR it = firstConditional; it != lastConditional;
|
||||
++it, ++nrFrontals) {
|
||||
DiscreteConditional::shared_ptr c = *it;
|
||||
DecisionTreeFactor::shared_ptr factor = c->toFactor();
|
||||
product = (*factor) * product;
|
||||
}
|
||||
// and then create a new multi-frontal conditional
|
||||
return boost::make_shared<DiscreteConditional>(nrFrontals,product);
|
||||
}
|
||||
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
|
||||
}
|
||||
|
||||
}// gtsam
|
||||
} // gtsam
|
||||
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteEliminationTree.cpp
|
||||
* @date Mar 29, 2013
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#include <gtsam/inference/EliminationTree-inst.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
template class EliminationTree<DiscreteBayesNet, DiscreteFactorGraph>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteEliminationTree::DiscreteEliminationTree(
|
||||
const DiscreteFactorGraph& factorGraph, const VariableIndex& structure,
|
||||
const Ordering& order) :
|
||||
Base(factorGraph, structure, order) {}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteEliminationTree::DiscreteEliminationTree(
|
||||
const DiscreteFactorGraph& factorGraph, const Ordering& order) :
|
||||
Base(factorGraph, order) {}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DiscreteEliminationTree::equals(const This& other, double tol) const
|
||||
{
|
||||
return Base::equals(other, tol);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteEliminationTree.h
|
||||
* @date Mar 29, 2013
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/inference/EliminationTree.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class GTSAM_EXPORT DiscreteEliminationTree :
|
||||
public EliminationTree<DiscreteBayesNet, DiscreteFactorGraph>
|
||||
{
|
||||
public:
|
||||
typedef EliminationTree<DiscreteBayesNet, DiscreteFactorGraph> Base; ///< Base class
|
||||
typedef DiscreteEliminationTree This; ///< This class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
|
||||
|
||||
/**
|
||||
* Build the elimination tree of a factor graph using pre-computed column structure.
|
||||
* @param factorGraph The factor graph for which to build the elimination tree
|
||||
* @param structure The set of factors involving each variable. If this is not
|
||||
* precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
|
||||
* named constructor instead.
|
||||
* @return The elimination tree
|
||||
*/
|
||||
DiscreteEliminationTree(const DiscreteFactorGraph& factorGraph,
|
||||
const VariableIndex& structure, const Ordering& order);
|
||||
|
||||
/** Build the elimination tree of a factor graph. Note that this has to compute the column
|
||||
* structure as a VariableIndex, so if you already have this precomputed, use the other
|
||||
* constructor instead.
|
||||
* @param factorGraph The factor graph for which to build the elimination tree
|
||||
*/
|
||||
DiscreteEliminationTree(const DiscreteFactorGraph& factorGraph,
|
||||
const Ordering& order);
|
||||
|
||||
/** Test whether the tree is equal to another */
|
||||
bool equals(const This& other, double tol = 1e-9) const;
|
||||
|
||||
private:
|
||||
|
||||
friend class ::EliminationTreeTester;
|
||||
|
||||
};
|
||||
|
||||
}
|
|
@ -24,9 +24,5 @@ using namespace std;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DiscreteFactor::DiscreteFactor() {
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -19,75 +19,72 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
#include <gtsam/inference/IndexFactorOrdered.h>
|
||||
#include <gtsam/inference/Factor.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class DecisionTreeFactor;
|
||||
class DiscreteConditional;
|
||||
class DecisionTreeFactor;
|
||||
class DiscreteConditional;
|
||||
|
||||
/**
|
||||
/**
|
||||
* Base class for discrete probabilistic factors
|
||||
* The most general one is the derived DecisionTreeFactor
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteFactor: public IndexFactorOrdered {
|
||||
class GTSAM_EXPORT DiscreteFactor: public Factor {
|
||||
|
||||
public:
|
||||
public:
|
||||
|
||||
// typedefs needed to play nice with gtsam
|
||||
typedef DiscreteFactor This;
|
||||
typedef DiscreteConditional ConditionalType;
|
||||
typedef boost::shared_ptr<DiscreteFactor> shared_ptr;
|
||||
typedef DiscreteFactor This; ///< This class
|
||||
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
|
||||
typedef Factor Base; ///< Our base class
|
||||
|
||||
/** A map from keys to values */
|
||||
/** A map from keys to values
|
||||
* TODO: Do we need this? Should we just use gtsam::Values?
|
||||
* We just need another special DiscreteValue to represent labels,
|
||||
* However, all other Lie's operators are undefined in this class.
|
||||
* The good thing is we can have a Hybrid graph of discrete/continuous variables
|
||||
* together..
|
||||
* Another good thing is we don't need to have the special DiscreteKey which stores
|
||||
* cardinality of a Discrete variable. It should be handled naturally in
|
||||
* the new class DiscreteValue, as the varible's type (domain)
|
||||
*/
|
||||
typedef Assignment<Index> Values;
|
||||
typedef boost::shared_ptr<Values> sharedValues;
|
||||
|
||||
protected:
|
||||
|
||||
/// Construct n-way factor
|
||||
DiscreteFactor(const std::vector<Index>& js) :
|
||||
IndexFactorOrdered(js) {
|
||||
}
|
||||
|
||||
/// Construct unary factor
|
||||
DiscreteFactor(Index j) :
|
||||
IndexFactorOrdered(j) {
|
||||
}
|
||||
|
||||
/// Construct binary factor
|
||||
DiscreteFactor(Index j1, Index j2) :
|
||||
IndexFactorOrdered(j1, j2) {
|
||||
}
|
||||
|
||||
/// construct from container
|
||||
template<class KeyIterator>
|
||||
DiscreteFactor(KeyIterator beginKey, KeyIterator endKey) :
|
||||
IndexFactorOrdered(beginKey, endKey) {
|
||||
}
|
||||
|
||||
public:
|
||||
public:
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor for I/O
|
||||
DiscreteFactor();
|
||||
/** Default constructor creates empty factor */
|
||||
DiscreteFactor() {}
|
||||
|
||||
/** Construct from container of keys. This constructor is used internally from derived factor
|
||||
* constructors, either from a container of keys or from a boost::assign::list_of. */
|
||||
template<typename CONTAINER>
|
||||
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
|
||||
|
||||
/// Virtual destructor
|
||||
virtual ~DiscreteFactor() {}
|
||||
virtual ~DiscreteFactor() {
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
// equals
|
||||
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0;
|
||||
|
||||
// print
|
||||
virtual void print(const std::string& s = "DiscreteFactor\n",
|
||||
const IndexFormatter& formatter
|
||||
=DefaultIndexFormatter) const {
|
||||
IndexFactorOrdered::print(s,formatter);
|
||||
const IndexFormatter& formatter = DefaultIndexFormatter) const {
|
||||
Factor::print(s, formatter);
|
||||
}
|
||||
|
||||
/** Test whether the factor is empty */
|
||||
virtual bool empty() const = 0;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
@ -100,23 +97,8 @@ namespace gtsam {
|
|||
|
||||
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
||||
|
||||
/**
|
||||
* Permutes the factor, but for efficiency requires the permutation
|
||||
* to already be inverted.
|
||||
*/
|
||||
virtual void permuteWithInverse(const Permutation& inversePermutation){
|
||||
IndexFactorOrdered::permuteWithInverse(inversePermutation);
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply a reduction, which is a remapping of variable indices.
|
||||
*/
|
||||
virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
|
||||
IndexFactorOrdered::reduceWithInverse(inverseReduction);
|
||||
}
|
||||
|
||||
/// @}
|
||||
};
|
||||
};
|
||||
// DiscreteFactor
|
||||
|
||||
}// namespace gtsam
|
||||
|
|
|
@ -19,23 +19,23 @@
|
|||
//#define ENABLE_TIMING
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/inference/EliminationTreeOrdered-inl.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Explicitly instantiate so we don't have to include everywhere
|
||||
template class FactorGraphOrdered<DiscreteFactor> ;
|
||||
template class EliminationTreeOrdered<DiscreteFactor> ;
|
||||
// Instantiate base classes
|
||||
template class FactorGraph<DiscreteFactor>;
|
||||
template class EliminateableFactorGraph<DiscreteFactorGraph>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactorGraph::DiscreteFactorGraph() {
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactorGraph::DiscreteFactorGraph(
|
||||
const BayesNetOrdered<DiscreteConditional>& bayesNet) :
|
||||
FactorGraphOrdered<DiscreteFactor>(bayesNet) {
|
||||
bool DiscreteFactorGraph::equals(const This& fg, double tol) const
|
||||
{
|
||||
return Base::equals(fg, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -75,27 +75,34 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DiscreteFactorGraph::permuteWithInverse(
|
||||
const Permutation& inversePermutation) {
|
||||
BOOST_FOREACH(const sharedFactor& factor, factors_) {
|
||||
if(factor)
|
||||
factor->permuteWithInverse(inversePermutation);
|
||||
}
|
||||
}
|
||||
// /* ************************************************************************* */
|
||||
// void DiscreteFactorGraph::permuteWithInverse(
|
||||
// const Permutation& inversePermutation) {
|
||||
// BOOST_FOREACH(const sharedFactor& factor, factors_) {
|
||||
// if(factor)
|
||||
// factor->permuteWithInverse(inversePermutation);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// /* ************************************************************************* */
|
||||
// void DiscreteFactorGraph::reduceWithInverse(
|
||||
// const internal::Reduction& inverseReduction) {
|
||||
// BOOST_FOREACH(const sharedFactor& factor, factors_) {
|
||||
// if(factor)
|
||||
// factor->reduceWithInverse(inverseReduction);
|
||||
// }
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
void DiscreteFactorGraph::reduceWithInverse(
|
||||
const internal::Reduction& inverseReduction) {
|
||||
BOOST_FOREACH(const sharedFactor& factor, factors_) {
|
||||
if(factor)
|
||||
factor->reduceWithInverse(inverseReduction);
|
||||
}
|
||||
DiscreteFactor::sharedValues DiscreteFactorGraph::optimize() const
|
||||
{
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
return BaseEliminateable::eliminateSequential()->optimize();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const FactorGraphOrdered<DiscreteFactor>& factors, size_t num) {
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
|
||||
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
|
@ -107,18 +114,22 @@ namespace gtsam {
|
|||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
gttic(sum);
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(num);
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
gttoc(sum);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
Ordering orderedKeys;
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
||||
|
||||
// now divide product/sum to get conditional
|
||||
gttic(divide);
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
|
||||
gttoc(divide);
|
||||
|
||||
return std::make_pair(cond, sum);
|
||||
}
|
||||
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -18,33 +18,85 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/inference/FactorGraphOrdered.h>
|
||||
#include <gtsam/base/FastSet.h>
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class DiscreteFactorGraph: public FactorGraphOrdered<DiscreteFactor> {
|
||||
// Forward declarations
|
||||
class DiscreteFactorGraph;
|
||||
class DiscreteFactor;
|
||||
class DiscreteConditional;
|
||||
class DiscreteBayesNet;
|
||||
class DiscreteEliminationTree;
|
||||
class DiscreteBayesTree;
|
||||
class DiscreteJunctionTree;
|
||||
|
||||
/** Main elimination function for DiscreteFactorGraph */
|
||||
std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<> struct EliminationTraits<DiscreteFactorGraph>
|
||||
{
|
||||
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
|
||||
typedef DiscreteFactorGraph FactorGraphType; ///< Type of the factor graph (e.g. DiscreteFactorGraph)
|
||||
typedef DiscreteConditional ConditionalType; ///< Type of conditionals from elimination
|
||||
typedef DiscreteBayesNet BayesNetType; ///< Type of Bayes net from sequential elimination
|
||||
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
|
||||
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
|
||||
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
|
||||
/// The default dense elimination function
|
||||
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
|
||||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||
return EliminateDiscrete(factors, keys); }
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
|
||||
* Factor == DiscreteFactor
|
||||
*/
|
||||
class DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
|
||||
public EliminateableFactorGraph<DiscreteFactorGraph> {
|
||||
public:
|
||||
|
||||
typedef DiscreteFactorGraph This; ///< Typedef to this class
|
||||
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
|
||||
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||
|
||||
/** A map from keys to values */
|
||||
typedef std::vector<Index> Indices;
|
||||
typedef Assignment<Index> Values;
|
||||
typedef boost::shared_ptr<Values> sharedValues;
|
||||
|
||||
/** Construct empty factor graph */
|
||||
GTSAM_EXPORT DiscreteFactorGraph();
|
||||
/** Default constructor */
|
||||
DiscreteFactorGraph() {}
|
||||
|
||||
/** Constructor from a factor graph of GaussianFactor or a derived type */
|
||||
/** Construct from iterator over factors */
|
||||
template<typename ITERATOR>
|
||||
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
|
||||
|
||||
/** Construct from container of factors (shared_ptr or plain objects) */
|
||||
template<class CONTAINER>
|
||||
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
|
||||
|
||||
/** Implicit copy/downcast constructor to override explicit template container constructor */
|
||||
template<class DERIVEDFACTOR>
|
||||
DiscreteFactorGraph(const FactorGraphOrdered<DERIVEDFACTOR>& fg) {
|
||||
push_back(fg);
|
||||
}
|
||||
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||
|
||||
/** construct from a BayesNet */
|
||||
GTSAM_EXPORT DiscreteFactorGraph(const BayesNetOrdered<DiscreteConditional>& bayesNet);
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
bool equals(const This& fg, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
|
||||
template<class SOURCE>
|
||||
void add(const DiscreteKey& j, SOURCE table) {
|
||||
|
@ -80,17 +132,19 @@ public:
|
|||
GTSAM_EXPORT void print(const std::string& s = "DiscreteFactorGraph",
|
||||
const IndexFormatter& formatter =DefaultIndexFormatter) const;
|
||||
|
||||
/** Permute the variables in the factors */
|
||||
GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
|
||||
/** Solve the factor graph by performing variable elimination in COLAMD order using
|
||||
* the dense elimination function specified in \c function,
|
||||
* followed by back-substitution resulting from elimination. Is equivalent
|
||||
* to calling graph.eliminateSequential()->optimize(). */
|
||||
DiscreteFactor::sharedValues optimize() const;
|
||||
|
||||
/** Apply a reduction, which is a remapping of variable indices. */
|
||||
GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
|
||||
|
||||
// /** Permute the variables in the factors */
|
||||
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
|
||||
//
|
||||
// /** Apply a reduction, which is a remapping of variable indices. */
|
||||
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
|
||||
};
|
||||
// DiscreteFactorGraph
|
||||
|
||||
/** Main elimination function for DiscreteFactorGraph */
|
||||
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
|
||||
EliminateDiscrete(const FactorGraphOrdered<DiscreteFactor>& factors,
|
||||
size_t nrFrontals = 1);
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteJunctionTree.cpp
|
||||
* @date Mar 29, 2013
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#include <gtsam/inference/JunctionTree-inst.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base classes
|
||||
template class ClusterTree<DiscreteBayesTree, DiscreteFactorGraph>;
|
||||
template class JunctionTree<DiscreteBayesTree, DiscreteFactorGraph>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteJunctionTree::DiscreteJunctionTree(
|
||||
const DiscreteEliminationTree& eliminationTree) :
|
||||
Base(eliminationTree) {}
|
||||
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteJunctionTree.h
|
||||
* @date Mar 29, 2013
|
||||
* @author Frank Dellaert
|
||||
* @author Richard Roberts
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
#include <gtsam/inference/JunctionTree.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
// Forward declarations
|
||||
class DiscreteEliminationTree;
|
||||
|
||||
/**
|
||||
* A ClusterTree, i.e., a set of variable clusters with factors, arranged in a tree, with
|
||||
* the additional property that it represents the clique tree associated with a Bayes net.
|
||||
*
|
||||
* In GTSAM a junction tree is an intermediate data structure in multifrontal
|
||||
* variable elimination. Each node is a cluster of factors, along with a
|
||||
* clique of variables that are eliminated all at once. In detail, every node k represents
|
||||
* a clique (maximal fully connected subset) of an associated chordal graph, such as a
|
||||
* chordal Bayes net resulting from elimination.
|
||||
*
|
||||
* The difference with the BayesTree is that a JunctionTree stores factors, whereas a
|
||||
* BayesTree stores conditionals, that are the product of eliminating the factors in the
|
||||
* corresponding JunctionTree cliques.
|
||||
*
|
||||
* The tree structure and elimination method are exactly analagous to the EliminationTree,
|
||||
* except that in the JunctionTree, at each node multiple variables are eliminated at a time.
|
||||
*
|
||||
* \addtogroup Multifrontal
|
||||
* \nosubgrouping
|
||||
*/
|
||||
class GTSAM_EXPORT DiscreteJunctionTree :
|
||||
public JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> {
|
||||
public:
|
||||
typedef JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> Base; ///< Base class
|
||||
typedef DiscreteJunctionTree This; ///< This class
|
||||
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
|
||||
|
||||
/**
|
||||
* Build the elimination tree of a factor graph using pre-computed column structure.
|
||||
* @param factorGraph The factor graph for which to build the elimination tree
|
||||
* @param structure The set of factors involving each variable. If this is not
|
||||
* precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
|
||||
* named constructor instead.
|
||||
* @return The elimination tree
|
||||
*/
|
||||
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
|
||||
};
|
||||
|
||||
}
|
|
@ -21,7 +21,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/inference/GenericMultifrontalSolver.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
#include <gtsam/base/Vector.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -32,7 +33,7 @@ namespace gtsam {
|
|||
|
||||
protected:
|
||||
|
||||
BayesTreeOrdered<DiscreteConditional> bayesTree_;
|
||||
DiscreteBayesTree::shared_ptr bayesTree_;
|
||||
|
||||
public:
|
||||
|
||||
|
@ -40,16 +41,14 @@ namespace gtsam {
|
|||
* @param graph The factor graph defining the full joint density on all variables.
|
||||
*/
|
||||
DiscreteMarginals(const DiscreteFactorGraph& graph) {
|
||||
typedef JunctionTreeOrdered<DiscreteFactorGraph> DiscreteJT;
|
||||
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
|
||||
bayesTree_ = *solver.eliminate(&EliminateDiscrete);
|
||||
bayesTree_ = graph.eliminateMultifrontal();
|
||||
}
|
||||
|
||||
/** Compute the marginal of a single variable */
|
||||
DiscreteFactor::shared_ptr operator()(Index variable) const {
|
||||
// Compute marginal
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
marginalFactor = bayesTree_.marginalFactor(variable, &EliminateDiscrete);
|
||||
marginalFactor = bayesTree_->marginalFactor(variable, &EliminateDiscrete);
|
||||
return marginalFactor;
|
||||
}
|
||||
|
||||
|
@ -60,7 +59,7 @@ namespace gtsam {
|
|||
Vector marginalProbabilities(const DiscreteKey& key) const {
|
||||
// Compute marginal
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
marginalFactor = bayesTree_.marginalFactor(key.first, &EliminateDiscrete);
|
||||
marginalFactor = bayesTree_->marginalFactor(key.first, &EliminateDiscrete);
|
||||
|
||||
//Create result
|
||||
Vector vResult(key.second);
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteSequentialSolver.cpp
|
||||
* @date Feb 16, 2011
|
||||
* @author Duy-Nguyen Ta
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
//#define ENABLE_TIMING
|
||||
#include <gtsam/discrete/DiscreteSequentialSolver.h>
|
||||
#include <gtsam/inference/GenericSequentialSolver-inl.h>
|
||||
#include <gtsam/base/timing.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
template class GenericSequentialSolver<DiscreteFactor> ;
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactor::sharedValues DiscreteSequentialSolver::optimize() const {
|
||||
|
||||
static const bool debug = false;
|
||||
|
||||
if (debug) this->factors_->print("DiscreteSequentialSolver, eliminating ");
|
||||
if (debug) this->eliminationTree_->print(
|
||||
"DiscreteSequentialSolver, elimination tree ");
|
||||
|
||||
// Eliminate using the elimination tree
|
||||
gttic(eliminate);
|
||||
DiscreteBayesNet::shared_ptr bayesNet = eliminate();
|
||||
gttoc(eliminate);
|
||||
|
||||
if (debug) bayesNet->print("DiscreteSequentialSolver, Bayes net ");
|
||||
|
||||
// Allocate the solution vector if it is not already allocated
|
||||
|
||||
// Back-substitute
|
||||
gttic(optimize);
|
||||
DiscreteFactor::sharedValues solution = gtsam::optimize(*bayesNet);
|
||||
gttoc(optimize);
|
||||
|
||||
if (debug) solution->print("DiscreteSequentialSolver, solution ");
|
||||
|
||||
return solution;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Vector DiscreteSequentialSolver::marginalProbabilities(
|
||||
const DiscreteKey& key) const {
|
||||
// Compute marginal
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
marginalFactor = Base::marginalFactor(key.first, &EliminateDiscrete);
|
||||
|
||||
//Create result
|
||||
Vector vResult(key.second);
|
||||
for (size_t state = 0; state < key.second; ++state) {
|
||||
DiscreteFactor::Values values;
|
||||
values[key.first] = state;
|
||||
vResult(state) = (*marginalFactor)(values);
|
||||
}
|
||||
return vResult;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
}
|
|
@ -1,113 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteSequentialSolver.h
|
||||
* @date Feb 16, 2011
|
||||
* @author Duy-Nguyen Ta
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/inference/GenericSequentialSolver.h>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
namespace gtsam {
|
||||
// The base class provides all of the needed functionality
|
||||
|
||||
class DiscreteSequentialSolver: public GenericSequentialSolver<DiscreteFactor> {
|
||||
|
||||
protected:
|
||||
typedef GenericSequentialSolver<DiscreteFactor> Base;
|
||||
typedef boost::shared_ptr<const DiscreteSequentialSolver> shared_ptr;
|
||||
|
||||
public:
|
||||
|
||||
/**
|
||||
* The problem we are trying to solve (SUM or MPE).
|
||||
*/
|
||||
typedef enum {
|
||||
BEL, // Belief updating (or conditional updating)
|
||||
MPE, // Most-Probable-Explanation
|
||||
MAP
|
||||
// Maximum A Posteriori hypothesis
|
||||
} ProblemType;
|
||||
|
||||
/**
|
||||
* Construct the solver for a factor graph. This builds the elimination
|
||||
* tree, which already does some of the work of elimination.
|
||||
*/
|
||||
DiscreteSequentialSolver(const FactorGraphOrdered<DiscreteFactor>& factorGraph) :
|
||||
Base(factorGraph) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct the solver with a shared pointer to a factor graph and to a
|
||||
* VariableIndex. The solver will store these pointers, so this constructor
|
||||
* is the fastest.
|
||||
*/
|
||||
DiscreteSequentialSolver(
|
||||
const FactorGraphOrdered<DiscreteFactor>::shared_ptr& factorGraph,
|
||||
const VariableIndexOrdered::shared_ptr& variableIndex) :
|
||||
Base(factorGraph, variableIndex) {
|
||||
}
|
||||
|
||||
const EliminationTreeOrdered<DiscreteFactor>& eliminationTree() const {
|
||||
return *eliminationTree_;
|
||||
}
|
||||
|
||||
/**
|
||||
* Eliminate the factor graph sequentially. Uses a column elimination tree
|
||||
* to recursively eliminate.
|
||||
*/
|
||||
BayesNetOrdered<DiscreteConditional>::shared_ptr eliminate() const {
|
||||
return Base::eliminate(&EliminateDiscrete);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the marginal joint over a set of variables, by integrating out
|
||||
* all of the other variables. This function returns the result as a factor
|
||||
* graph.
|
||||
*/
|
||||
DiscreteFactorGraph::shared_ptr jointFactorGraph(
|
||||
const std::vector<Index>& js) const {
|
||||
DiscreteFactorGraph::shared_ptr results(new DiscreteFactorGraph(
|
||||
*Base::jointFactorGraph(js, &EliminateDiscrete)));
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the marginal density over a variable, by integrating out
|
||||
* all of the other variables. This function returns the result as a factor.
|
||||
*/
|
||||
DiscreteFactor::shared_ptr marginalFactor(Index j) const {
|
||||
return Base::marginalFactor(j, &EliminateDiscrete);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the marginal density over a variable, by integrating out
|
||||
* all of the other variables. This function returns the result as a
|
||||
* Vector of the probability values.
|
||||
*/
|
||||
GTSAM_EXPORT Vector marginalProbabilities(const DiscreteKey& key) const;
|
||||
|
||||
/**
|
||||
* Compute the MPE solution of the DiscreteFactorGraph. This
|
||||
* eliminates to create a BayesNet and then back-substitutes this BayesNet to
|
||||
* obtain the solution.
|
||||
*/
|
||||
GTSAM_EXPORT DiscreteFactor::sharedValues optimize() const;
|
||||
|
||||
};
|
||||
|
||||
} // gtsam
|
|
@ -59,39 +59,39 @@ namespace gtsam {
|
|||
cout << endl;
|
||||
ADT::print(" ");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
template<class P>
|
||||
void Potentials::remapIndices(const P& remapping) {
|
||||
// Permute the _cardinalities (TODO: Inefficient Consider Improving)
|
||||
DiscreteKeys keys;
|
||||
map<Index, Index> ordering;
|
||||
|
||||
// Get the original keys from cardinalities_
|
||||
BOOST_FOREACH(const DiscreteKey& key, cardinalities_)
|
||||
keys & key;
|
||||
|
||||
// Perform Permutation
|
||||
BOOST_FOREACH(DiscreteKey& key, keys) {
|
||||
ordering[key.first] = remapping[key.first];
|
||||
key.first = ordering[key.first];
|
||||
}
|
||||
|
||||
// Change *this
|
||||
AlgebraicDecisionTree<Index> permuted((*this), ordering);
|
||||
*this = permuted;
|
||||
cardinalities_ = keys.cardinalities();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
|
||||
remapIndices(inversePermutation);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
|
||||
remapIndices(inverseReduction);
|
||||
}
|
||||
//
|
||||
// /* ************************************************************************* */
|
||||
// template<class P>
|
||||
// void Potentials::remapIndices(const P& remapping) {
|
||||
// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
|
||||
// DiscreteKeys keys;
|
||||
// map<Index, Index> ordering;
|
||||
//
|
||||
// // Get the original keys from cardinalities_
|
||||
// BOOST_FOREACH(const DiscreteKey& key, cardinalities_)
|
||||
// keys & key;
|
||||
//
|
||||
// // Perform Permutation
|
||||
// BOOST_FOREACH(DiscreteKey& key, keys) {
|
||||
// ordering[key.first] = remapping[key.first];
|
||||
// key.first = ordering[key.first];
|
||||
// }
|
||||
//
|
||||
// // Change *this
|
||||
// AlgebraicDecisionTree<Index> permuted((*this), ordering);
|
||||
// *this = permuted;
|
||||
// cardinalities_ = keys.cardinalities();
|
||||
// }
|
||||
//
|
||||
// /* ************************************************************************* */
|
||||
// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
|
||||
// remapIndices(inversePermutation);
|
||||
// }
|
||||
//
|
||||
// /* ************************************************************************* */
|
||||
// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
|
||||
// remapIndices(inverseReduction);
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/inference/PermutationOrdered.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <set>
|
||||
|
@ -48,9 +47,9 @@ namespace gtsam {
|
|||
// Safe division for probabilities
|
||||
GTSAM_EXPORT static double safe_div(const double& a, const double& b);
|
||||
|
||||
// Apply either a permutation or a reduction
|
||||
template<class P>
|
||||
void remapIndices(const P& remapping);
|
||||
// // Apply either a permutation or a reduction
|
||||
// template<class P>
|
||||
// void remapIndices(const P& remapping);
|
||||
|
||||
public:
|
||||
|
||||
|
@ -73,19 +72,19 @@ namespace gtsam {
|
|||
|
||||
size_t cardinality(Index j) const { return cardinalities_.at(j);}
|
||||
|
||||
/**
|
||||
* @brief Permutes the keys in Potentials
|
||||
*
|
||||
* This permutes the Indices and performs necessary re-ordering of ADD.
|
||||
* This is virtual so that derived types e.g. DecisionTreeFactor can
|
||||
* re-implement it.
|
||||
*/
|
||||
GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
|
||||
|
||||
/**
|
||||
* Apply a reduction, which is a remapping of variable indices.
|
||||
*/
|
||||
GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
|
||||
// /**
|
||||
// * @brief Permutes the keys in Potentials
|
||||
// *
|
||||
// * This permutes the Indices and performs necessary re-ordering of ADD.
|
||||
// * This is virtual so that derived types e.g. DecisionTreeFactor can
|
||||
// * re-implement it.
|
||||
// */
|
||||
// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
|
||||
//
|
||||
// /**
|
||||
// * Apply a reduction, which is a remapping of variable indices.
|
||||
// */
|
||||
// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
|
||||
|
||||
}; // Potentials
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
/**
|
||||
* @file Signature.cpp
|
||||
* @brief signatures for conditional densities
|
||||
* @author Frank dellaert
|
||||
* @author Frank Dellaert
|
||||
* @date Feb 27, 2011
|
||||
*/
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
/**
|
||||
* @file Signature.h
|
||||
* @brief signatures for conditional densities
|
||||
* @author Frank dellaert
|
||||
* @author Frank Dellaert
|
||||
* @date Feb 27, 2011
|
||||
*/
|
||||
|
||||
|
|
|
@ -17,13 +17,15 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteSequentialSolver.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
#include <boost/assign/list_inserter.hpp>
|
||||
|
||||
using namespace boost::assign;
|
||||
|
||||
#include <iostream>
|
||||
|
@ -40,38 +42,40 @@ TEST(DiscreteBayesNet, Asia)
|
|||
DiscreteKey A(0,2), S(4,2), T(3,2), L(6,2), B(7,2), E(5,2), X(2,2), D(1,2);
|
||||
|
||||
// TODO: make a version that doesn't use the parser
|
||||
add_front(asia, A % "99/1");
|
||||
add_front(asia, S % "50/50");
|
||||
asia.add(A % "99/1");
|
||||
asia.add(S % "50/50");
|
||||
|
||||
add_front(asia, T | A = "99/1 95/5");
|
||||
add_front(asia, L | S = "99/1 90/10");
|
||||
add_front(asia, B | S = "70/30 40/60");
|
||||
asia.add(T | A = "99/1 95/5");
|
||||
asia.add(L | S = "99/1 90/10");
|
||||
asia.add(B | S = "70/30 40/60");
|
||||
|
||||
add_front(asia, (E | T, L) = "F T T T");
|
||||
asia.add((E | T, L) = "F T T T");
|
||||
|
||||
add_front(asia, X | E = "95/5 2/98");
|
||||
// next lines are same as add_front(asia, (D | E, B) = "9/1 2/8 3/7 1/9");
|
||||
asia.add(X | E = "95/5 2/98");
|
||||
// next lines are same as asia.add((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||
DiscreteConditional::shared_ptr actual =
|
||||
boost::make_shared<DiscreteConditional>((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||
asia.push_front(actual);
|
||||
asia.push_back(actual);
|
||||
// GTSAM_PRINT(asia);
|
||||
|
||||
// Convert to factor graph
|
||||
DiscreteFactorGraph fg(asia);
|
||||
// GTSAM_PRINT(fg);
|
||||
LONGS_EQUAL(3,fg.front()->size());
|
||||
// GTSAM_PRINT(fg);
|
||||
LONGS_EQUAL(3,fg.back()->size());
|
||||
Potentials::ADT expected(B & D & E, "0.9 0.3 0.1 0.7 0.2 0.1 0.8 0.9");
|
||||
CHECK(assert_equal(expected,(Potentials::ADT)*actual));
|
||||
|
||||
// Create solver and eliminate
|
||||
DiscreteSequentialSolver solver(fg);
|
||||
DiscreteBayesNet::shared_ptr chordal = solver.eliminate();
|
||||
// GTSAM_PRINT(*chordal);
|
||||
Ordering ordering;
|
||||
ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7);
|
||||
ordering.print("ordering: ");
|
||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||
// GTSAM_PRINT(*chordal);
|
||||
DiscreteConditional expected2(B % "11/9");
|
||||
CHECK(assert_equal(expected2,*chordal->back()));
|
||||
|
||||
// solve
|
||||
DiscreteFactor::sharedValues actualMPE = optimize(*chordal);
|
||||
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
|
||||
DiscreteFactor::Values expectedMPE;
|
||||
insert(expectedMPE)(A.first, 0)(D.first, 0)(X.first, 0)(T.first, 0)(S.first,
|
||||
0)(E.first, 0)(L.first, 0)(B.first, 0);
|
||||
|
@ -83,10 +87,10 @@ TEST(DiscreteBayesNet, Asia)
|
|||
// fg.product().dot("fg");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteSequentialSolver solver2(fg);
|
||||
DiscreteBayesNet::shared_ptr chordal2 = solver2.eliminate();
|
||||
// GTSAM_PRINT(*chordal2);
|
||||
DiscreteFactor::sharedValues actualMPE2 = optimize(*chordal2);
|
||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential();
|
||||
GTSAM_PRINT(*chordal2);
|
||||
DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
|
||||
cout << "optimized!" << endl;
|
||||
DiscreteFactor::Values expectedMPE2;
|
||||
insert(expectedMPE2)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(S.first,
|
||||
1)(E.first, 0)(L.first, 0)(B.first, 1);
|
||||
|
@ -97,8 +101,8 @@ TEST(DiscreteBayesNet, Asia)
|
|||
SETDEBUG("DiscreteConditional::sample", false);
|
||||
insert(expectedSample)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(
|
||||
S.first, 1)(E.first, 0)(L.first, 0)(B.first, 1);
|
||||
DiscreteFactor::sharedValues actualSample = sample(*chordal2);
|
||||
EXPECT(assert_equal(expectedSample, *actualSample));
|
||||
DiscreteFactor::sharedValues actualSample = chordal->sample();
|
||||
// EXPECT(assert_equal(expectedSample, *actualSample));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -114,12 +118,12 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar)
|
|||
// add(bn, D | E = "blah");
|
||||
|
||||
// try logic
|
||||
add(bn, (E | T, L) = "OR");
|
||||
add(bn, (E | T, L) = "AND");
|
||||
bn.add((E | T, L) = "OR");
|
||||
bn.add((E | T, L) = "AND");
|
||||
|
||||
// // try multivalued
|
||||
add(bn, C % "1/1/2");
|
||||
add(bn, C | S = "1/1/2 5/2/3");
|
||||
bn.add(C % "1/1/2");
|
||||
bn.add(C | S = "1/1/2 5/2/3");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -1,261 +1,261 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testDiscreteBayesTree.cpp
|
||||
* @date sept 15, 2012
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/inference/BayesTreeOrdered.h>
|
||||
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
///* ----------------------------------------------------------------------------
|
||||
//
|
||||
// * GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
// * Atlanta, Georgia 30332-0415
|
||||
// * All Rights Reserved
|
||||
// * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
//
|
||||
// * See LICENSE for the license information
|
||||
//
|
||||
// * -------------------------------------------------------------------------- */
|
||||
//
|
||||
///*
|
||||
// * @file testDiscreteBayesTree.cpp
|
||||
// * @date sept 15, 2012
|
||||
// * @author Frank Dellaert
|
||||
// */
|
||||
//
|
||||
//#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
//#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
//#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
//
|
||||
//#include <boost/assign/std/vector.hpp>
|
||||
//using namespace boost::assign;
|
||||
//
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
static bool debug = false;
|
||||
|
||||
/**
|
||||
* Custom clique class to debug shortcuts
|
||||
*/
|
||||
class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> {
|
||||
|
||||
protected:
|
||||
|
||||
public:
|
||||
|
||||
typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base;
|
||||
typedef boost::shared_ptr<Clique> shared_ptr;
|
||||
|
||||
// Constructors
|
||||
Clique() {
|
||||
}
|
||||
Clique(const DiscreteConditional::shared_ptr& conditional) :
|
||||
Base(conditional) {
|
||||
}
|
||||
Clique(
|
||||
const std::pair<DiscreteConditional::shared_ptr,
|
||||
DiscreteConditional::FactorType::shared_ptr>& result) :
|
||||
Base(result) {
|
||||
}
|
||||
|
||||
/// print index signature only
|
||||
void printSignature(const std::string& s = "Clique: ",
|
||||
const IndexFormatter& indexFormatter = DefaultIndexFormatter) const {
|
||||
((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter);
|
||||
}
|
||||
|
||||
/// evaluate value of sub-tree
|
||||
double evaluate(const DiscreteConditional::Values & values) {
|
||||
double result = (*(this->conditional_))(values);
|
||||
// evaluate all children and multiply into result
|
||||
BOOST_FOREACH(boost::shared_ptr<Clique> c, children_)
|
||||
result *= c->evaluate(values);
|
||||
return result;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
|
||||
|
||||
/* ************************************************************************* */
|
||||
double evaluate(const DiscreteBayesTree& tree,
|
||||
const DiscreteConditional::Values & values) {
|
||||
return tree.root()->evaluate(values);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
||||
|
||||
const int nrNodes = 15;
|
||||
const size_t nrStates = 2;
|
||||
|
||||
// define variables
|
||||
vector<DiscreteKey> key;
|
||||
for (int i = 0; i < nrNodes; i++) {
|
||||
DiscreteKey key_i(i, nrStates);
|
||||
key.push_back(key_i);
|
||||
}
|
||||
|
||||
// create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||
DiscreteBayesNet bayesNet;
|
||||
add_front(bayesNet, key[14] % "1/3");
|
||||
|
||||
add_front(bayesNet, key[13] | key[14] = "1/3 3/1");
|
||||
add_front(bayesNet, key[12] | key[14] = "3/1 3/1");
|
||||
|
||||
add_front(bayesNet, (key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
|
||||
add_front(bayesNet, (key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
|
||||
add_front(bayesNet, (key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
|
||||
add_front(bayesNet, (key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
|
||||
|
||||
add_front(bayesNet, (key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
|
||||
add_front(bayesNet, (key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
|
||||
add_front(bayesNet, (key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
|
||||
add_front(bayesNet, (key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
|
||||
|
||||
add_front(bayesNet, (key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
|
||||
add_front(bayesNet, (key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
|
||||
add_front(bayesNet, (key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
|
||||
add_front(bayesNet, (key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
|
||||
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesNet);
|
||||
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||
}
|
||||
|
||||
// create a BayesTree out of a Bayes net
|
||||
DiscreteBayesTree bayesTree(bayesNet);
|
||||
if (debug) {
|
||||
GTSAM_PRINT(bayesTree);
|
||||
bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
|
||||
}
|
||||
|
||||
// Check whether BN and BT give the same answer on all configurations
|
||||
// Also calculate all some marginals
|
||||
Vector marginals = zero(15);
|
||||
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
||||
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
||||
joint_4_11 = 0;
|
||||
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
||||
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
|
||||
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||
for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
DiscreteFactor::Values x = allPosbValues[i];
|
||||
double expected = evaluate(bayesNet, x);
|
||||
double actual = evaluate(bayesTree, x);
|
||||
DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||
// collect marginals
|
||||
for (size_t i = 0; i < 15; i++)
|
||||
if (x[i])
|
||||
marginals[i] += actual;
|
||||
// calculate shortcut 8 and 0
|
||||
if (x[12] && x[14])
|
||||
joint_12_14 += actual;
|
||||
if (x[9] && x[12] & x[14])
|
||||
joint_9_12_14 += actual;
|
||||
if (x[8] && x[12] & x[14])
|
||||
joint_8_12_14 += actual;
|
||||
if (x[8] && x[12])
|
||||
joint_8_12 += actual;
|
||||
if (x[8] && x[2])
|
||||
joint82 += actual;
|
||||
if (x[1] && x[2])
|
||||
joint12 += actual;
|
||||
if (x[2] && x[4])
|
||||
joint24 += actual;
|
||||
if (x[4] && x[5])
|
||||
joint45 += actual;
|
||||
if (x[4] && x[6])
|
||||
joint46 += actual;
|
||||
if (x[4] && x[11])
|
||||
joint_4_11 += actual;
|
||||
}
|
||||
DiscreteFactor::Values all1 = allPosbValues.back();
|
||||
|
||||
Clique::shared_ptr R = bayesTree.root();
|
||||
|
||||
// check separator marginal P(S0)
|
||||
Clique::shared_ptr c = bayesTree[0];
|
||||
DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R,
|
||||
EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||
|
||||
// check separator marginal P(S9), should be P(14)
|
||||
c = bayesTree[9];
|
||||
DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R,
|
||||
EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||
|
||||
// check separator marginal of root, should be empty
|
||||
c = bayesTree[11];
|
||||
DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
|
||||
EliminateDiscrete);
|
||||
EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
|
||||
|
||||
// check shortcut P(S9||R) to root
|
||||
c = bayesTree[9];
|
||||
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
EXPECT_LONGS_EQUAL(0, shortcut.size());
|
||||
|
||||
// check shortcut P(S8||R) to root
|
||||
c = bayesTree[8];
|
||||
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
|
||||
1e-9);
|
||||
|
||||
// check shortcut P(S2||R) to root
|
||||
c = bayesTree[2];
|
||||
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
|
||||
1e-9);
|
||||
|
||||
// check shortcut P(S0||R) to root
|
||||
c = bayesTree[0];
|
||||
shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
|
||||
1e-9);
|
||||
|
||||
// calculate all shortcuts to root
|
||||
DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
|
||||
BOOST_FOREACH(Clique::shared_ptr c, cliques) {
|
||||
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
if (debug) {
|
||||
c->printSignature();
|
||||
shortcut.print("shortcut:");
|
||||
}
|
||||
}
|
||||
|
||||
// Check all marginals
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
for (size_t i = 0; i < 15; i++) {
|
||||
marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
|
||||
double actual = (*marginalFactor)(all1);
|
||||
EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||
}
|
||||
|
||||
DiscreteBayesNet::shared_ptr actualJoint;
|
||||
|
||||
// Check joint P(8,2) TODO: not disjoint !
|
||||
// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(1,2) TODO: not disjoint !
|
||||
// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(2,4)
|
||||
actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(4,5) TODO: not disjoint !
|
||||
// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(4,6) TODO: not disjoint !
|
||||
// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
// Check joint P(4,11)
|
||||
actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
|
||||
EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
|
||||
|
||||
}
|
||||
//
|
||||
//using namespace std;
|
||||
//using namespace gtsam;
|
||||
//
|
||||
//static bool debug = false;
|
||||
//
|
||||
///**
|
||||
// * Custom clique class to debug shortcuts
|
||||
// */
|
||||
////class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> {
|
||||
////
|
||||
////protected:
|
||||
////
|
||||
////public:
|
||||
////
|
||||
//// typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base;
|
||||
//// typedef boost::shared_ptr<Clique> shared_ptr;
|
||||
////
|
||||
//// // Constructors
|
||||
//// Clique() {
|
||||
//// }
|
||||
//// Clique(const DiscreteConditional::shared_ptr& conditional) :
|
||||
//// Base(conditional) {
|
||||
//// }
|
||||
//// Clique(
|
||||
//// const std::pair<DiscreteConditional::shared_ptr,
|
||||
//// DiscreteConditional::FactorType::shared_ptr>& result) :
|
||||
//// Base(result) {
|
||||
//// }
|
||||
////
|
||||
//// /// print index signature only
|
||||
//// void printSignature(const std::string& s = "Clique: ",
|
||||
//// const IndexFormatter& indexFormatter = DefaultIndexFormatter) const {
|
||||
//// ((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter);
|
||||
//// }
|
||||
////
|
||||
//// /// evaluate value of sub-tree
|
||||
//// double evaluate(const DiscreteConditional::Values & values) {
|
||||
//// double result = (*(this->conditional_))(values);
|
||||
//// // evaluate all children and multiply into result
|
||||
//// BOOST_FOREACH(boost::shared_ptr<Clique> c, children_)
|
||||
//// result *= c->evaluate(values);
|
||||
//// return result;
|
||||
//// }
|
||||
////
|
||||
////};
|
||||
//
|
||||
////typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
|
||||
////
|
||||
/////* ************************************************************************* */
|
||||
////double evaluate(const DiscreteBayesTree& tree,
|
||||
//// const DiscreteConditional::Values & values) {
|
||||
//// return tree.root()->evaluate(values);
|
||||
////}
|
||||
//
|
||||
///* ************************************************************************* */
|
||||
//
|
||||
//TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
|
||||
//
|
||||
// const int nrNodes = 15;
|
||||
// const size_t nrStates = 2;
|
||||
//
|
||||
// // define variables
|
||||
// vector<DiscreteKey> key;
|
||||
// for (int i = 0; i < nrNodes; i++) {
|
||||
// DiscreteKey key_i(i, nrStates);
|
||||
// key.push_back(key_i);
|
||||
// }
|
||||
//
|
||||
// // create a thin-tree Bayesnet, a la Jean-Guillaume
|
||||
// DiscreteBayesNet bayesNet;
|
||||
// bayesNet.add(key[14] % "1/3");
|
||||
//
|
||||
// bayesNet.add(key[13] | key[14] = "1/3 3/1");
|
||||
// bayesNet.add(key[12] | key[14] = "3/1 3/1");
|
||||
//
|
||||
// bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
|
||||
// bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
|
||||
// bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
|
||||
// bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
|
||||
//
|
||||
// bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
|
||||
// bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
|
||||
// bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
|
||||
// bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
|
||||
//
|
||||
// bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
|
||||
// bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
|
||||
// bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
|
||||
// bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
|
||||
//
|
||||
//// if (debug) {
|
||||
//// GTSAM_PRINT(bayesNet);
|
||||
//// bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
|
||||
//// }
|
||||
//
|
||||
// // create a BayesTree out of a Bayes net
|
||||
// DiscreteBayesTree bayesTree(bayesNet);
|
||||
// if (debug) {
|
||||
// GTSAM_PRINT(bayesTree);
|
||||
// bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
|
||||
// }
|
||||
//
|
||||
// // Check whether BN and BT give the same answer on all configurations
|
||||
// // Also calculate all some marginals
|
||||
// Vector marginals = zero(15);
|
||||
// double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
|
||||
// joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
|
||||
// joint_4_11 = 0;
|
||||
// vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
|
||||
// key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
|
||||
// & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
|
||||
// for (size_t i = 0; i < allPosbValues.size(); ++i) {
|
||||
// DiscreteFactor::Values x = allPosbValues[i];
|
||||
// double expected = evaluate(bayesNet, x);
|
||||
// double actual = evaluate(bayesTree, x);
|
||||
// DOUBLES_EQUAL(expected, actual, 1e-9);
|
||||
// // collect marginals
|
||||
// for (size_t i = 0; i < 15; i++)
|
||||
// if (x[i])
|
||||
// marginals[i] += actual;
|
||||
// // calculate shortcut 8 and 0
|
||||
// if (x[12] && x[14])
|
||||
// joint_12_14 += actual;
|
||||
// if (x[9] && x[12] & x[14])
|
||||
// joint_9_12_14 += actual;
|
||||
// if (x[8] && x[12] & x[14])
|
||||
// joint_8_12_14 += actual;
|
||||
// if (x[8] && x[12])
|
||||
// joint_8_12 += actual;
|
||||
// if (x[8] && x[2])
|
||||
// joint82 += actual;
|
||||
// if (x[1] && x[2])
|
||||
// joint12 += actual;
|
||||
// if (x[2] && x[4])
|
||||
// joint24 += actual;
|
||||
// if (x[4] && x[5])
|
||||
// joint45 += actual;
|
||||
// if (x[4] && x[6])
|
||||
// joint46 += actual;
|
||||
// if (x[4] && x[11])
|
||||
// joint_4_11 += actual;
|
||||
// }
|
||||
// DiscreteFactor::Values all1 = allPosbValues.back();
|
||||
//
|
||||
// Clique::shared_ptr R = bayesTree.root();
|
||||
//
|
||||
// // check separator marginal P(S0)
|
||||
// Clique::shared_ptr c = bayesTree[0];
|
||||
// DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R,
|
||||
// EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
|
||||
//
|
||||
// // check separator marginal P(S9), should be P(14)
|
||||
// c = bayesTree[9];
|
||||
// DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R,
|
||||
// EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
|
||||
//
|
||||
// // check separator marginal of root, should be empty
|
||||
// c = bayesTree[11];
|
||||
// DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
|
||||
// EliminateDiscrete);
|
||||
// EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
|
||||
//
|
||||
// // check shortcut P(S9||R) to root
|
||||
// c = bayesTree[9];
|
||||
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
// EXPECT_LONGS_EQUAL(0, shortcut.size());
|
||||
//
|
||||
// // check shortcut P(S8||R) to root
|
||||
// c = bayesTree[8];
|
||||
// shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
|
||||
// 1e-9);
|
||||
//
|
||||
// // check shortcut P(S2||R) to root
|
||||
// c = bayesTree[2];
|
||||
// shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
|
||||
// 1e-9);
|
||||
//
|
||||
// // check shortcut P(S0||R) to root
|
||||
// c = bayesTree[0];
|
||||
// shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
|
||||
// 1e-9);
|
||||
//
|
||||
// // calculate all shortcuts to root
|
||||
// DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
|
||||
// BOOST_FOREACH(Clique::shared_ptr c, cliques) {
|
||||
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
|
||||
// if (debug) {
|
||||
// c->printSignature();
|
||||
// shortcut.print("shortcut:");
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Check all marginals
|
||||
// DiscreteFactor::shared_ptr marginalFactor;
|
||||
// for (size_t i = 0; i < 15; i++) {
|
||||
// marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
|
||||
// double actual = (*marginalFactor)(all1);
|
||||
// EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
|
||||
// }
|
||||
//
|
||||
// DiscreteBayesNet::shared_ptr actualJoint;
|
||||
//
|
||||
// // Check joint P(8,2) TODO: not disjoint !
|
||||
//// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
|
||||
//// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
|
||||
//
|
||||
// // Check joint P(1,2) TODO: not disjoint !
|
||||
//// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
|
||||
//// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
|
||||
//
|
||||
// // Check joint P(2,4)
|
||||
// actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
|
||||
//
|
||||
// // Check joint P(4,5) TODO: not disjoint !
|
||||
//// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
|
||||
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
||||
//
|
||||
// // Check joint P(4,6) TODO: not disjoint !
|
||||
//// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
|
||||
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
|
||||
//
|
||||
// // Check joint P(4,11)
|
||||
// actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
|
||||
// EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
|
||||
//
|
||||
//}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteSequentialSolver.h>
|
||||
#include <gtsam/inference/GenericMultifrontalSolver.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteBayesTree.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
|
@ -58,9 +58,9 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
|||
// result.first->print("BayesNet: ");
|
||||
// result.second->print("New factor: ");
|
||||
//
|
||||
DiscreteSequentialSolver solver(graph);
|
||||
// solver.print("solver:");
|
||||
EliminationTreeOrdered<DiscreteFactor> eliminationTree = solver.eliminationTree();
|
||||
Ordering ordering;
|
||||
ordering += Key(0),Key(1),Key(2),Key(3);
|
||||
DiscreteEliminationTree eliminationTree(graph, ordering);
|
||||
// eliminationTree.print("Elimination tree: ");
|
||||
eliminationTree.eliminate(EliminateDiscrete);
|
||||
// solver.optimize();
|
||||
|
@ -127,10 +127,11 @@ TEST( DiscreteFactorGraph, test)
|
|||
|
||||
// Test EliminateDiscrete
|
||||
// FIXME: apparently Eliminate returns a conditional rather than a net
|
||||
Ordering frontalKeys;
|
||||
frontalKeys += Key(0);
|
||||
DiscreteConditional::shared_ptr conditional;
|
||||
DecisionTreeFactor::shared_ptr newFactor;
|
||||
boost::tie(conditional, newFactor) =//
|
||||
EliminateDiscrete(graph, 1);
|
||||
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
||||
|
||||
// Check Bayes net
|
||||
CHECK(conditional);
|
||||
|
@ -139,34 +140,35 @@ TEST( DiscreteFactorGraph, test)
|
|||
// cout << signature << endl;
|
||||
DiscreteConditional expectedConditional(signature);
|
||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||
add(expected, signature);
|
||||
expected.add(signature);
|
||||
|
||||
// Check Factor
|
||||
CHECK(newFactor);
|
||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||
|
||||
// Create solver
|
||||
DiscreteSequentialSolver solver(graph);
|
||||
|
||||
// add conditionals to complete expected Bayes net
|
||||
add(expected, B | A = "5/3 3/5");
|
||||
add(expected, A % "1/1");
|
||||
expected.add(B | A = "5/3 3/5");
|
||||
expected.add(A % "1/1");
|
||||
// GTSAM_PRINT(expected);
|
||||
|
||||
// Test elimination tree
|
||||
const EliminationTreeOrdered<DiscreteFactor>& etree = solver.eliminationTree();
|
||||
DiscreteBayesNet::shared_ptr actual = etree.eliminate(&EliminateDiscrete);
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2);
|
||||
DiscreteEliminationTree etree(graph, ordering);
|
||||
DiscreteBayesNet::shared_ptr actual;
|
||||
DiscreteFactorGraph::shared_ptr remainingGraph;
|
||||
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
||||
EXPECT(assert_equal(expected, *actual));
|
||||
|
||||
// Test solver
|
||||
DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
|
||||
EXPECT(assert_equal(expected, *actual2));
|
||||
// // Test solver
|
||||
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
|
||||
// EXPECT(assert_equal(expected, *actual2));
|
||||
|
||||
// Test optimization
|
||||
DiscreteFactor::Values expectedValues;
|
||||
insert(expectedValues)(0, 0)(1, 0)(2, 0);
|
||||
DiscreteFactor::sharedValues actualValues = solver.optimize();
|
||||
DiscreteFactor::sharedValues actualValues = graph.optimize();
|
||||
EXPECT(assert_equal(expectedValues, *actualValues));
|
||||
}
|
||||
|
||||
|
@ -183,8 +185,7 @@ TEST( DiscreteFactorGraph, testMPE)
|
|||
// graph.product().print();
|
||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
||||
|
||||
DiscreteFactor::sharedValues actualMPE =
|
||||
DiscreteSequentialSolver(graph).optimize();
|
||||
DiscreteFactor::sharedValues actualMPE = graph.optimize();
|
||||
|
||||
DiscreteFactor::Values expectedMPE;
|
||||
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
|
||||
|
@ -213,18 +214,23 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
|||
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
|
||||
|
||||
// Use the solver machinery.
|
||||
DiscreteBayesNet::shared_ptr chordal =
|
||||
DiscreteSequentialSolver(graph).eliminate();
|
||||
DiscreteFactor::sharedValues actualMPE = optimize(*chordal);
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
|
||||
EXPECT(assert_equal(expectedMPE, *actualMPE));
|
||||
// DiscreteConditional::shared_ptr root = chordal->back();
|
||||
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
|
||||
|
||||
// Let us create the Bayes tree here, just for fun, because we don't use it now
|
||||
typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
|
||||
GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
|
||||
BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
|
||||
// bayesTree->print("Bayes Tree");
|
||||
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
|
||||
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
|
||||
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
|
||||
//// bayesTree->print("Bayes Tree");
|
||||
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
||||
|
||||
Ordering ordering;
|
||||
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
|
||||
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
|
||||
// bayesTree->print("Bayes Tree");
|
||||
EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
||||
|
||||
#ifdef OLD
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteMarginals.h>
|
||||
#include <gtsam/discrete/DiscreteSequentialSolver.h>
|
||||
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
using namespace boost::assign;
|
||||
|
@ -119,11 +118,9 @@ TEST_UNSAFE( DiscreteMarginals, truss ) {
|
|||
graph.add(key[0] & key[2] & key[4],"2 2 2 2 1 1 1 1");
|
||||
graph.add(key[1] & key[3] & key[4],"1 1 1 1 2 2 2 2");
|
||||
graph.add(key[2] & key[3] & key[4],"1 1 1 1 1 1 1 1");
|
||||
typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
|
||||
GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
|
||||
BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
|
||||
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal();
|
||||
// bayesTree->print("Bayes Tree");
|
||||
typedef BayesTreeClique<DiscreteConditional> Clique;
|
||||
typedef DiscreteBayesTreeClique Clique;
|
||||
|
||||
Clique expected0(boost::make_shared<DiscreteConditional>((key[0] | key[2], key[4]) = "2/1 2/1 2/1 2/1"));
|
||||
Clique::shared_ptr actual0 = (*bayesTree)[0];
|
||||
|
@ -180,16 +177,16 @@ TEST_UNSAFE( DiscreteMarginals, truss2 ) {
|
|||
}
|
||||
|
||||
// Check all marginals given by a sequential solver and Marginals
|
||||
DiscreteSequentialSolver solver(graph);
|
||||
// DiscreteSequentialSolver solver(graph);
|
||||
DiscreteMarginals marginals(graph);
|
||||
for (size_t j=0;j<5;j++) {
|
||||
double sum = T[j]+F[j];
|
||||
T[j]/=sum;
|
||||
F[j]/=sum;
|
||||
|
||||
// solver
|
||||
Vector actualV = solver.marginalProbabilities(key[j]);
|
||||
EXPECT(assert_equal(Vector_(2,F[j],T[j]), actualV));
|
||||
// // solver
|
||||
// Vector actualV = solver.marginalProbabilities(key[j]);
|
||||
// EXPECT(assert_equal(Vector_(2,F[j],T[j]), actualV));
|
||||
|
||||
// Marginals
|
||||
vector<double> table;
|
||||
|
|
Loading…
Reference in New Issue