revive discrete, except testDiscreteBayesNet::sample, and the whole testDiscreteBayesTree, currently disabled. It's still a mess!!!

release/4.3a0
Duy-Nguyen Ta 2013-10-10 21:59:49 +00:00
parent 1a8a670870
commit 9aede467d2
30 changed files with 1347 additions and 1042 deletions

View File

@ -6,7 +6,7 @@ set (gtsam_subdirs
geometry geometry
inference inference
symbolic symbolic
#discrete discrete
linear linear
nonlinear nonlinear
slam slam
@ -71,7 +71,7 @@ set(gtsam_srcs
${geometry_srcs} ${geometry_srcs}
${inference_srcs} ${inference_srcs}
${symbolic_srcs} ${symbolic_srcs}
#${discrete_srcs} ${discrete_srcs}
${linear_srcs} ${linear_srcs}
${nonlinear_srcs} ${nonlinear_srcs}
${slam_srcs} ${slam_srcs}

View File

@ -11,7 +11,7 @@
/** /**
* @file DecisionTree.h * @file DecisionTree.h
* @brief Decision Tree for use in DiscreteFactors * @brief Decision Tree for use in DiscreteFactors
* @author Frank Dellaert * @author Frank Dellaert
* @author Can Erdogan * @author Can Erdogan
* @date Jan 30, 2012 * @date Jan 30, 2012
@ -643,6 +643,7 @@ namespace gtsam {
// where there is no more branch on "label": only the subtree under that // where there is no more branch on "label": only the subtree under that
// branch point corresponding to the value "index" is left instead. // branch point corresponding to the value "index" is left instead.
// The function below get all these smaller trees and "ops" them together. // The function below get all these smaller trees and "ops" them together.
// This implements marginalization in Darwiche09book, pg 330
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::combine(const L& label, DecisionTree<L, Y> DecisionTree<L, Y>::combine(const L& label,
size_t cardinality, const Binary& op) const { size_t cardinality, const Binary& op) const {

View File

@ -11,7 +11,7 @@
/** /**
* @file DecisionTree.h * @file DecisionTree.h
* @brief Decision Tree for use in DiscreteFactors * @brief Decision Tree for use in DiscreteFactors
* @author Frank Dellaert * @author Frank Dellaert
* @author Can Erdogan * @author Can Erdogan
* @date Jan 30, 2012 * @date Jan 30, 2012
@ -177,7 +177,8 @@ namespace gtsam {
/** apply binary operation "op" to f and g */ /** apply binary operation "op" to f and g */
DecisionTree apply(const DecisionTree& g, const Binary& op) const; DecisionTree apply(const DecisionTree& g, const Binary& op) const;
/** create a new function where value(label)==index */ /** create a new function where value(label)==index
* It's like "restrict" in Darwiche09book pg329, 330? */
DecisionTree choose(const L& label, size_t index) const { DecisionTree choose(const L& label, size_t index) const {
NodePtr newRoot = root_->choose(label, index); NodePtr newRoot = root_->choose(label, index);
return DecisionTree(newRoot); return DecisionTree(newRoot);

View File

@ -44,21 +44,25 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
bool DecisionTreeFactor::equals(const This& other, double tol) const { bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
return IndexFactorOrdered::equals(other, tol) && Potentials::equals(other, tol); if(!dynamic_cast<const DecisionTreeFactor*>(&other))
return false;
else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return Potentials::equals(f, tol);
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
void DecisionTreeFactor::print(const string& s, void DecisionTreeFactor::print(const string& s,
const IndexFormatter& formatter) const { const IndexFormatter& formatter) const {
cout << s; cout << s;
IndexFactorOrdered::print("IndexFactor:",formatter);
Potentials::print("Potentials:",formatter); Potentials::print("Potentials:",formatter);
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor DecisionTreeFactor::apply // DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
(const DecisionTreeFactor& f, ADT::Binary op) const { ADT::Binary op) const {
map<Index,size_t> cs; // new cardinalities map<Index,size_t> cs; // new cardinalities
// make unique key-cardinality map // make unique key-cardinality map
BOOST_FOREACH(Index j, keys()) cs[j] = cardinality(j); BOOST_FOREACH(Index j, keys()) cs[j] = cardinality(j);
@ -74,8 +78,8 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine // DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
(size_t nrFrontals, ADT::Binary op) const { ADT::Binary op) const {
if (nrFrontals > size()) throw invalid_argument( if (nrFrontals > size()) throw invalid_argument(
(boost::format( (boost::format(
@ -99,5 +103,36 @@ namespace gtsam {
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
ADT::Binary op) const {
if (frontalKeys.size() > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% frontalKeys.size() % size()).str());
// sum over nrFrontals keys
size_t i;
ADT result(*this);
for (i = 0; i < frontalKeys.size(); i++) {
Index j = frontalKeys[i];
result = result.combine(j, cardinality(j), op);
}
// create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) {
Index j = keys()[i];
// TODO: inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
continue;
dkeys.push_back(DiscreteKey(j,cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -20,6 +20,7 @@
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/Potentials.h> #include <gtsam/discrete/Potentials.h>
#include <gtsam/inference/Ordering.h>
#include <boost/foreach.hpp> #include <boost/foreach.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
@ -41,7 +42,7 @@ namespace gtsam {
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This; typedef DecisionTreeFactor This;
typedef DiscreteConditional ConditionalType; typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr; typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
public: public:
@ -69,7 +70,7 @@ namespace gtsam {
/// @{ /// @{
/// equality /// equality
bool equals(const DecisionTreeFactor& other, double tol = 1e-9) const; bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
// print // print
virtual void print(const std::string& s = "DecisionTreeFactor:\n", virtual void print(const std::string& s = "DecisionTreeFactor:\n",
@ -104,6 +105,11 @@ namespace gtsam {
return combine(nrFrontals, ADT::Ring::add); return combine(nrFrontals, ADT::Ring::add);
} }
/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
return combine(keys, ADT::Ring::add);
}
/// Create new factor by maximizing over all values with the same separator values /// Create new factor by maximizing over all values with the same separator values
shared_ptr max(size_t nrFrontals) const { shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max); return combine(nrFrontals, ADT::Ring::max);
@ -129,27 +135,39 @@ namespace gtsam {
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
/** /**
* @brief Permutes the keys in Potentials and DiscreteFactor * Combine frontal variables in an Ordering using binary operator "op"
* * @param nrFrontals nr. of frontal to combine variables in this factor
* This re-implements the permuteWithInverse() in both Potentials * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
* and DiscreteFactor by doing both of them together. * @return shared pointer to newly created DecisionTreeFactor
*/ */
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
void permuteWithInverse(const Permutation& inversePermutation){
DiscreteFactor::permuteWithInverse(inversePermutation); /** Test whether the factor is empty */
Potentials::permuteWithInverse(inversePermutation); virtual bool empty() const { return size() == 0; }
}
// /**
/** // * @brief Permutes the keys in Potentials and DiscreteFactor
* Apply a reduction, which is a remapping of variable indices. // *
*/ // * This re-implements the permuteWithInverse() in both Potentials
virtual void reduceWithInverse(const internal::Reduction& inverseReduction) { // * and DiscreteFactor by doing both of them together.
DiscreteFactor::reduceWithInverse(inverseReduction); // */
Potentials::reduceWithInverse(inverseReduction); //
} // void permuteWithInverse(const Permutation& inversePermutation){
// DiscreteFactor::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
// }
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
// DiscreteFactor::reduceWithInverse(inverseReduction);
// Potentials::reduceWithInverse(inverseReduction);
// }
/// @} /// @}
}; };
// DecisionTreeFactor // DecisionTreeFactor
}// namespace gtsam }// namespace gtsam

View File

@ -18,47 +18,53 @@
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/BayesNetOrdered-inl.h> #include <gtsam/inference/FactorGraph-inst.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
// Explicitly instantiate so we don't have to include everywhere // Instantiate base class
template class BayesNetOrdered<DiscreteConditional> ; template class FactorGraph<DiscreteConditional>;
/* ************************************************************************* */ /* ************************************************************************* */
void add_front(DiscreteBayesNet& bayesNet, const Signature& s) { bool DiscreteBayesNet::equals(const This& bn, double tol) const
bayesNet.push_front(boost::make_shared<DiscreteConditional>(s)); {
return Base::equals(bn, tol);
} }
/* ************************************************************************* */ /* ************************************************************************* */
void add(DiscreteBayesNet& bayesNet, const Signature& s) { // void DiscreteBayesNet::add_front(const Signature& s) {
bayesNet.push_back(boost::make_shared<DiscreteConditional>(s)); // push_front(boost::make_shared<DiscreteConditional>(s));
// }
/* ************************************************************************* */
void DiscreteBayesNet::add(const Signature& s) {
push_back(boost::make_shared<DiscreteConditional>(s));
} }
/* ************************************************************************* */ /* ************************************************************************* */
double evaluate(const DiscreteBayesNet& bn, const DiscreteConditional::Values & values) { double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const {
// evaluate all conditionals and multiply // evaluate all conditionals and multiply
double result = 1.0; double result = 1.0;
BOOST_FOREACH(DiscreteConditional::shared_ptr conditional, bn) BOOST_FOREACH(DiscreteConditional::shared_ptr conditional, *this)
result *= (*conditional)(values); result *= (*conditional)(values);
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn) { DiscreteFactor::sharedValues DiscreteBayesNet::optimize() const {
// solve each node in turn in topological sort order (parents first) // solve each node in turn in topological sort order (parents first)
DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
BOOST_REVERSE_FOREACH (DiscreteConditional::shared_ptr conditional, bn) BOOST_REVERSE_FOREACH (DiscreteConditional::shared_ptr conditional, *this)
conditional->solveInPlace(*result); conditional->solveInPlace(*result);
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn) { DiscreteFactor::sharedValues DiscreteBayesNet::sample() const {
// sample each node in turn in topological sort order (parents first) // sample each node in turn in topological sort order (parents first)
DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
BOOST_REVERSE_FOREACH(DiscreteConditional::shared_ptr conditional, bn) BOOST_REVERSE_FOREACH(DiscreteConditional::shared_ptr conditional, *this)
conditional->sampleInPlace(*result); conditional->sampleInPlace(*result);
return result; return result;
} }

View File

@ -20,27 +20,80 @@
#include <vector> #include <vector>
#include <map> #include <map>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <gtsam/inference/BayesNetOrdered.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
namespace gtsam { namespace gtsam {
typedef BayesNetOrdered<DiscreteConditional> DiscreteBayesNet; /** A Bayes net made from linear-Discrete densities */
class GTSAM_EXPORT DiscreteBayesNet: public FactorGraph<DiscreteConditional>
{
public:
/** Add a DiscreteCondtional */ typedef FactorGraph<DiscreteConditional> Base;
GTSAM_EXPORT void add(DiscreteBayesNet&, const Signature& s); typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr;
typedef boost::shared_ptr<ConditionalType> sharedConditional;
/** Add a DiscreteCondtional in front, when listing parents first*/ /// @name Standard Constructors
GTSAM_EXPORT void add_front(DiscreteBayesNet&, const Signature& s); /// @{
//** evaluate for given Values */ /** Construct empty factor graph */
GTSAM_EXPORT double evaluate(const DiscreteBayesNet& bn, const DiscreteConditional::Values & values); DiscreteBayesNet() {}
/** Optimize function for back-substitution. */ /** Construct from iterator over conditionals */
GTSAM_EXPORT DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn); template<typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
/** Do ancestral sampling */ /** Construct from container of factors (shared_ptr or plain objects) */
GTSAM_EXPORT DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn); template<class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
/// @}
/// @name Testable
/// @{
/** Check equality */
bool equals(const This& bn, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteCondtional */
GTSAM_EXPORT void add(const Signature& s);
// /** Add a DiscreteCondtional in front, when listing parents first*/
// GTSAM_EXPORT void add_front(const Signature& s);
//** evaluate for given Values */
GTSAM_EXPORT double evaluate(const DiscreteConditional::Values & values) const;
/**
* Solve the DiscreteBayesNet by back-substitution
*/
DiscreteFactor::sharedValues optimize() const;
/** Do ancestral sampling */
GTSAM_EXPORT DiscreteFactor::sharedValues sample() const;
///@}
private:
/** Serialization function */
friend class boost::serialization::access;
template<class ARCHIVE>
void serialize(ARCHIVE & ar, const unsigned int version) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
} // namespace } // namespace

View File

@ -0,0 +1,43 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteBayesTree.cpp
* @brief Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree
* @brief DiscreteBayesTree
* @author Frank Dellaert
* @author Richard Roberts
*/
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
namespace gtsam {
// Instantiate base class
template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>;
template class BayesTree<DiscreteBayesTreeClique>;
/* ************************************************************************* */
bool DiscreteBayesTree::equals(const This& other, double tol) const
{
return Base::equals(other, tol);
}
} // \namespace gtsam

View File

@ -0,0 +1,66 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteBayesTree.h
* @brief Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree
* @brief DiscreteBayesTree
* @author Frank Dellaert
* @author Richard Roberts
*/
#pragma once
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesTree.h>
#include <gtsam/inference/BayesTreeCliqueBase.h>
namespace gtsam {
// Forward declarations
class DiscreteConditional;
class VectorValues;
/* ************************************************************************* */
/** A clique in a DiscreteBayesTree */
class GTSAM_EXPORT DiscreteBayesTreeClique :
public BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>
{
public:
typedef DiscreteBayesTreeClique This;
typedef BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> Base;
typedef boost::shared_ptr<This> shared_ptr;
typedef boost::weak_ptr<This> weak_ptr;
DiscreteBayesTreeClique() {}
DiscreteBayesTreeClique(const boost::shared_ptr<DiscreteConditional>& conditional) : Base(conditional) {}
};
/* ************************************************************************* */
/** A Bayes tree representing a Discrete density */
class GTSAM_EXPORT DiscreteBayesTree :
public BayesTree<DiscreteBayesTreeClique>
{
private:
typedef BayesTree<DiscreteBayesTreeClique> Base;
public:
typedef DiscreteBayesTree This;
typedef boost::shared_ptr<This> shared_ptr;
/** Default constructor, creates an empty Bayes tree */
DiscreteBayesTree() {}
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;
};
}

View File

@ -18,6 +18,7 @@
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h> #include <gtsam/base/debug.h>
@ -34,174 +35,188 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ******************************************************************************** */ // Instantiate base class
DiscreteConditional::DiscreteConditional(const size_t nrFrontals, template class Conditional<DecisionTreeFactor, DiscreteConditional> ;
const DecisionTreeFactor& f) :
IndexConditionalOrdered(f.keys(), nrFrontals), Potentials(
f / (*f.sum(nrFrontals))) {
}
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal) :
IndexConditionalOrdered(joint.keys(), joint.size() - marginal.size()), Potentials(
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) {
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
cout << (firstFrontalKey()) << endl; //TODO Print all keys
}
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const Signature& signature) :
IndexConditionalOrdered(signature.indices(), 1), Potentials(
signature.discreteKeysParentsFirst(), signature.cpt()) {
}
/* ******************************************************************************** */
void DiscreteConditional::print(const std::string& s, const IndexFormatter& formatter) const {
std::cout << s << std::endl;
IndexConditionalOrdered::print(s, formatter);
Potentials::print(s);
}
/* ******************************************************************************** */
bool DiscreteConditional::equals(const DiscreteConditional& other, double tol) const {
return IndexConditionalOrdered::equals(other, tol)
&& Potentials::equals(other, tol);
}
/* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(
const Values& parentsValues) const {
ADT pFS(*this);
BOOST_FOREACH(Index key, parents())
try {
Index j = (key);
size_t value = parentsValues.at(j);
pFS = pFS.choose(j, value);
} catch (exception&) {
throw runtime_error(
"DiscreteConditional::choose: parent value missing");
};
return pFS;
}
/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(Values& values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
ADT pFS = choose(values); // P(F|S=parentsValues)
// Initialize
Values mpe;
double maxP = 0;
DiscreteKeys keys;
BOOST_FOREACH(Index idx, frontals()) {
DiscreteKey dk(idx,cardinality(idx));
keys & dk;
}
// Get all Possible Configurations
vector<Values> allPosbValues = cartesianProduct(keys);
// Find the MPE
BOOST_FOREACH(Values& frontalVals, allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
//set values (inPlace) to mpe
BOOST_FOREACH(Index j, frontals()) {
values[j] = mpe[j];
}
}
/* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(Values& values) const {
assert(nrFrontals() == 1);
Index j = (firstFrontalKey());
size_t sampled = sample(values); // Sample variable
values[j] = sampled; // store result in partial solution
}
/* ******************************************************************************** */
size_t DiscreteConditional::solve(const Values& parentsValues) const {
// TODO: is this really the fastest way? I think it is.
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
size_t mpe = 0;
Values frontals;
double maxP = 0;
assert(nrFrontals() == 1);
Index j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
}
}
return mpe;
}
/* ******************************************************************************** */
size_t DiscreteConditional::sample(const Values& parentsValues) const {
using boost::uniform_real;
static boost::mt19937 gen(2); // random number generator
bool debug = ISDEBUG("DiscreteConditional::sample");
// Get the correct conditional density
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
if (debug) GTSAM_PRINT(pFS);
// get cumulative distribution function (cdf)
// TODO, only works for one key now, seems horribly slow this way
assert(nrFrontals() == 1);
Index j = (firstFrontalKey());
size_t nj = cardinality(j);
vector<double> cdf(nj);
Values frontals;
double sum = 0;
for (size_t value = 0; value < nj; value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
sum += pValueS; // accumulate
if (debug) cout << sum << " ";
if (pValueS == 1) {
if (debug) cout << "--> " << value << endl;
return value; // shortcut exit
}
cdf[value] = sum;
}
// inspired by http://www.boost.org/doc/libs/1_46_1/doc/html/boost_random/tutorial.html
uniform_real<> dist(0, cdf.back());
boost::variate_generator<boost::mt19937&, uniform_real<> > die(gen, dist);
size_t sampled = lower_bound(cdf.begin(), cdf.end(), die()) - cdf.begin();
if (debug) cout << "-> " << sampled << endl;
return sampled;
return 0;
}
/* ******************************************************************************** */
void DiscreteConditional::permuteWithInverse(const Permutation& inversePermutation){
IndexConditionalOrdered::permuteWithInverse(inversePermutation);
Potentials::permuteWithInverse(inversePermutation);
}
/* ******************************************************************************** */ /* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f) :
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
}
} // namespace /* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal, const boost::optional<Ordering>& orderedKeys) :
BaseFactor(
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
joint.size()-marginal.size()) {
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
cout << (firstFrontalKey()) << endl; //TODO Print all keys
if (orderedKeys) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys->begin(), orderedKeys->end());
}
}
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const Signature& signature) :
BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional(
1) {
}
/* ******************************************************************************** */
void DiscreteConditional::print(const std::string& s,
const IndexFormatter& formatter) const {
std::cout << s << std::endl;
Potentials::print(s);
}
/* ******************************************************************************** */
bool DiscreteConditional::equals(const DiscreteConditional& other,
double tol) const {
return Potentials::equals(other, tol);
}
/* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const {
ADT pFS(*this);
Index j; size_t value;
BOOST_FOREACH(Index key, parents())
try {
j = (key);
value = parentsValues.at(j);
pFS = pFS.choose(j, value);
} catch (exception&) {
cout << "Key: " << j << " Value: " << value << endl;
pFS.print("pFS: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
return pFS;
}
/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(Values& values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
print("Me: ");
values.print("values:");
cout << "nrFrontals/nrParents: " << nrFrontals() << "/" << nrParents() << endl;
cout << "first frontal: " << firstFrontalKey() << endl;
ADT pFS = choose(values); // P(F|S=parentsValues)
// Initialize
Values mpe;
double maxP = 0;
DiscreteKeys keys;
BOOST_FOREACH(Index idx, frontals()) {
DiscreteKey dk(idx, cardinality(idx));
keys & dk;
}
// Get all Possible Configurations
vector<Values> allPosbValues = cartesianProduct(keys);
// Find the MPE
BOOST_FOREACH(Values& frontalVals, allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
//set values (inPlace) to mpe
BOOST_FOREACH(Index j, frontals()) {
values[j] = mpe[j];
}
}
/* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(Values& values) const {
assert(nrFrontals() == 1);
Index j = (firstFrontalKey());
size_t sampled = sample(values); // Sample variable
values[j] = sampled; // store result in partial solution
}
/* ******************************************************************************** */
size_t DiscreteConditional::solve(const Values& parentsValues) const {
// TODO: is this really the fastest way? I think it is.
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
size_t mpe = 0;
Values frontals;
double maxP = 0;
assert(nrFrontals() == 1);
Index j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
}
}
return mpe;
}
/* ******************************************************************************** */
size_t DiscreteConditional::sample(const Values& parentsValues) const {
using boost::uniform_real;
static boost::mt19937 gen(2); // random number generator
bool debug = ISDEBUG("DiscreteConditional::sample");
// Get the correct conditional density
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
if (debug)
GTSAM_PRINT(pFS);
// get cumulative distribution function (cdf)
// TODO, only works for one key now, seems horribly slow this way
assert(nrFrontals() == 1);
Index j = (firstFrontalKey());
size_t nj = cardinality(j);
vector<double> cdf(nj);
Values frontals;
double sum = 0;
for (size_t value = 0; value < nj; value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
sum += pValueS; // accumulate
if (debug)
cout << sum << " ";
if (pValueS == 1) {
if (debug)
cout << "--> " << value << endl;
return value; // shortcut exit
}
cdf[value] = sum;
}
// inspired by http://www.boost.org/doc/libs/1_46_1/doc/html/boost_random/tutorial.html
uniform_real<> dist(0, cdf.back());
boost::variate_generator<boost::mt19937&, uniform_real<> > die(gen, dist);
size_t sampled = lower_bound(cdf.begin(), cdf.end(), die()) - cdf.begin();
if (debug)
cout << "-> " << sampled << endl;
return sampled;
return 0;
}
/* ******************************************************************************** */
//void DiscreteConditional::permuteWithInverse(
// const Permutation& inversePermutation) {
// IndexConditionalOrdered::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
//}
/* ******************************************************************************** */
}// namespace

View File

@ -20,134 +20,134 @@
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/IndexConditionalOrdered.h> #include <gtsam/inference/Conditional.h>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
/** /**
* Discrete Conditional Density * Discrete Conditional Density
* Derives from DecisionTreeFactor * Derives from DecisionTreeFactor
*/ */
class GTSAM_EXPORT DiscreteConditional: public IndexConditionalOrdered, public Potentials { class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
public Conditional<DecisionTreeFactor, DiscreteConditional> {
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DiscreteFactor FactorType; typedef DiscreteConditional This; ///< Typedef to this class
typedef boost::shared_ptr<DiscreteConditional> shared_ptr; typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef IndexConditionalOrdered Base; typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
/** A map from keys to values */ /** A map from keys to values..
typedef Assignment<Index> Values; * TODO: Again, do we need this??? */
typedef boost::shared_ptr<Values> sharedValues; typedef Assignment<Index> Values;
typedef boost::shared_ptr<Values> sharedValues;
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** default constructor needed for serialization */ /** default constructor needed for serialization */
DiscreteConditional() { DiscreteConditional() {
}
/** constructor from factor */
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
/** Construct from signature */
DiscreteConditional(const Signature& signature);
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
/**
* Combine several conditional into a single one.
* The conditionals must be given in increasing order, meaning that the parents
* of any conditional may not include a conditional coming before it.
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
* */
template<typename ITERATOR>
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
/// @}
/// @name Testable
/// @{
/// GTSAM-style print
void print(const std::string& s = "Discrete Conditional: ",
const IndexFormatter& formatter = DefaultIndexFormatter) const;
/// GTSAM-style equals
bool equals(const DiscreteConditional& other, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/// Evaluate, just look up in AlgebraicDecisonTree
virtual double operator()(const Values& values) const {
return Potentials::operator()(values);
}
/** Convert to a factor */
DecisionTreeFactor::shared_ptr toFactor() const {
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
}
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
ADT choose(const Assignment<Index>& parentsValues) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents
* @return MPE value of the child (1 frontal variable).
*/
size_t solve(const Values& parentsValues) const;
/**
* sample
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
size_t sample(const Values& parentsValues) const;
/// @}
/// @name Advanced Interface
/// @{
/// solve a conditional, in place
void solveInPlace(Values& parentsValues) const;
/// sample in place, stores result in partial solution
void sampleInPlace(Values& parentsValues) const;
/**
* Permutes both IndexConditional and Potentials.
*/
void permuteWithInverse(const Permutation& inversePermutation);
/// @}
};
// DiscreteConditional
/* ************************************************************************* */
template<typename ITERATOR>
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
ITERATOR firstConditional, ITERATOR lastConditional) {
// TODO: check for being a clique
// multiply all the potentials of the given conditionals
size_t nrFrontals = 0;
DecisionTreeFactor product;
for(ITERATOR it = firstConditional; it != lastConditional; ++it, ++nrFrontals) {
DiscreteConditional::shared_ptr c = *it;
DecisionTreeFactor::shared_ptr factor = c->toFactor();
product = (*factor) * product;
}
// and then create a new multi-frontal conditional
return boost::make_shared<DiscreteConditional>(nrFrontals,product);
} }
}// gtsam /** constructor from factor */
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
/** Construct from signature */
DiscreteConditional(const Signature& signature);
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal, const boost::optional<Ordering>& orderedKeys = boost::none);
/**
* Combine several conditional into a single one.
* The conditionals must be given in increasing order, meaning that the parents
* of any conditional may not include a conditional coming before it.
* @param firstConditional Iterator to the first conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
* */
template<typename ITERATOR>
static shared_ptr Combine(ITERATOR firstConditional,
ITERATOR lastConditional);
/// @}
/// @name Testable
/// @{
/// GTSAM-style print
void print(const std::string& s = "Discrete Conditional: ",
const IndexFormatter& formatter = DefaultIndexFormatter) const;
/// GTSAM-style equals
bool equals(const DiscreteConditional& other, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/// Evaluate, just look up in AlgebraicDecisonTree
virtual double operator()(const Values& values) const {
return Potentials::operator()(values);
}
/** Convert to a factor */
DecisionTreeFactor::shared_ptr toFactor() const {
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
}
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
ADT choose(const Assignment<Index>& parentsValues) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents
* @return MPE value of the child (1 frontal variable).
*/
size_t solve(const Values& parentsValues) const;
/**
* sample
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
size_t sample(const Values& parentsValues) const;
/// @}
/// @name Advanced Interface
/// @{
/// solve a conditional, in place
void solveInPlace(Values& parentsValues) const;
/// sample in place, stores result in partial solution
void sampleInPlace(Values& parentsValues) const;
/// @}
};
// DiscreteConditional
/* ************************************************************************* */
template<typename ITERATOR>
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
ITERATOR firstConditional, ITERATOR lastConditional) {
// TODO: check for being a clique
// multiply all the potentials of the given conditionals
size_t nrFrontals = 0;
DecisionTreeFactor product;
for (ITERATOR it = firstConditional; it != lastConditional;
++it, ++nrFrontals) {
DiscreteConditional::shared_ptr c = *it;
DecisionTreeFactor::shared_ptr factor = c->toFactor();
product = (*factor) * product;
}
// and then create a new multi-frontal conditional
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
}
} // gtsam

View File

@ -0,0 +1,44 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteEliminationTree.cpp
* @date Mar 29, 2013
* @author Frank Dellaert
* @author Richard Roberts
*/
#include <gtsam/inference/EliminationTree-inst.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
namespace gtsam {
// Instantiate base class
template class EliminationTree<DiscreteBayesNet, DiscreteFactorGraph>;
/* ************************************************************************* */
DiscreteEliminationTree::DiscreteEliminationTree(
const DiscreteFactorGraph& factorGraph, const VariableIndex& structure,
const Ordering& order) :
Base(factorGraph, structure, order) {}
/* ************************************************************************* */
DiscreteEliminationTree::DiscreteEliminationTree(
const DiscreteFactorGraph& factorGraph, const Ordering& order) :
Base(factorGraph, order) {}
/* ************************************************************************* */
bool DiscreteEliminationTree::equals(const This& other, double tol) const
{
return Base::equals(other, tol);
}
}

View File

@ -0,0 +1,63 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteEliminationTree.h
* @date Mar 29, 2013
* @author Frank Dellaert
* @author Richard Roberts
*/
#pragma once
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/EliminationTree.h>
namespace gtsam {
class GTSAM_EXPORT DiscreteEliminationTree :
public EliminationTree<DiscreteBayesNet, DiscreteFactorGraph>
{
public:
typedef EliminationTree<DiscreteBayesNet, DiscreteFactorGraph> Base; ///< Base class
typedef DiscreteEliminationTree This; ///< This class
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
/**
* Build the elimination tree of a factor graph using pre-computed column structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is not
* precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
DiscreteEliminationTree(const DiscreteFactorGraph& factorGraph,
const VariableIndex& structure, const Ordering& order);
/** Build the elimination tree of a factor graph. Note that this has to compute the column
* structure as a VariableIndex, so if you already have this precomputed, use the other
* constructor instead.
* @param factorGraph The factor graph for which to build the elimination tree
*/
DiscreteEliminationTree(const DiscreteFactorGraph& factorGraph,
const Ordering& order);
/** Test whether the tree is equal to another */
bool equals(const This& other, double tol = 1e-9) const;
private:
friend class ::EliminationTreeTester;
};
}

View File

@ -24,9 +24,5 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ******************************************************************************** */
DiscreteFactor::DiscreteFactor() {
}
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -19,104 +19,86 @@
#pragma once #pragma once
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <gtsam/inference/IndexFactorOrdered.h> #include <gtsam/inference/Factor.h>
namespace gtsam { namespace gtsam {
class DecisionTreeFactor; class DecisionTreeFactor;
class DiscreteConditional; class DiscreteConditional;
/** /**
* Base class for discrete probabilistic factors * Base class for discrete probabilistic factors
* The most general one is the derived DecisionTreeFactor * The most general one is the derived DecisionTreeFactor
*/
class GTSAM_EXPORT DiscreteFactor: public Factor {
public:
// typedefs needed to play nice with gtsam
typedef DiscreteFactor This; ///< This class
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class
/** A map from keys to values
* TODO: Do we need this? Should we just use gtsam::Values?
* We just need another special DiscreteValue to represent labels,
* However, all other Lie's operators are undefined in this class.
* The good thing is we can have a Hybrid graph of discrete/continuous variables
* together..
* Another good thing is we don't need to have the special DiscreteKey which stores
* cardinality of a Discrete variable. It should be handled naturally in
* the new class DiscreteValue, as the varible's type (domain)
*/ */
class GTSAM_EXPORT DiscreteFactor: public IndexFactorOrdered { typedef Assignment<Index> Values;
typedef boost::shared_ptr<Values> sharedValues;
public: public:
// typedefs needed to play nice with gtsam /// @name Standard Constructors
typedef DiscreteFactor This; /// @{
typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<DiscreteFactor> shared_ptr;
/** A map from keys to values */ /** Default constructor creates empty factor */
typedef Assignment<Index> Values; DiscreteFactor() {}
typedef boost::shared_ptr<Values> sharedValues;
protected: /** Construct from container of keys. This constructor is used internally from derived factor
* constructors, either from a container of keys or from a boost::assign::list_of. */
template<typename CONTAINER>
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
/// Construct n-way factor /// Virtual destructor
DiscreteFactor(const std::vector<Index>& js) : virtual ~DiscreteFactor() {
IndexFactorOrdered(js) { }
}
/// Construct unary factor /// @}
DiscreteFactor(Index j) : /// @name Testable
IndexFactorOrdered(j) { /// @{
}
/// Construct binary factor // equals
DiscreteFactor(Index j1, Index j2) : virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0;
IndexFactorOrdered(j1, j2) {
}
/// construct from container // print
template<class KeyIterator> virtual void print(const std::string& s = "DiscreteFactor\n",
DiscreteFactor(KeyIterator beginKey, KeyIterator endKey) : const IndexFormatter& formatter = DefaultIndexFormatter) const {
IndexFactorOrdered(beginKey, endKey) { Factor::print(s, formatter);
} }
public: /** Test whether the factor is empty */
virtual bool empty() const = 0;
/// @name Standard Constructors /// @}
/// @{ /// @name Standard Interface
/// @{
/// Default constructor for I/O /// Find value for given assignment of values to variables
DiscreteFactor(); virtual double operator()(const Values&) const = 0;
/// Virtual destructor /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual ~DiscreteFactor() {} virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
/// @} virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/// @name Testable
/// @{
// print /// @}
virtual void print(const std::string& s = "DiscreteFactor\n", };
const IndexFormatter& formatter
=DefaultIndexFormatter) const {
IndexFactorOrdered::print(s,formatter);
}
/// @}
/// @name Standard Interface
/// @{
/// Find value for given assignment of values to variables
virtual double operator()(const Values&) const = 0;
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/**
* Permutes the factor, but for efficiency requires the permutation
* to already be inverted.
*/
virtual void permuteWithInverse(const Permutation& inversePermutation){
IndexFactorOrdered::permuteWithInverse(inversePermutation);
}
/**
* Apply a reduction, which is a remapping of variable indices.
*/
virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
IndexFactorOrdered::reduceWithInverse(inverseReduction);
}
/// @}
};
// DiscreteFactor // DiscreteFactor
}// namespace gtsam }// namespace gtsam

View File

@ -19,23 +19,23 @@
//#define ENABLE_TIMING //#define ENABLE_TIMING
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/EliminationTreeOrdered-inl.h> #include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
// Explicitly instantiate so we don't have to include everywhere // Instantiate base classes
template class FactorGraphOrdered<DiscreteFactor> ; template class FactorGraph<DiscreteFactor>;
template class EliminationTreeOrdered<DiscreteFactor> ; template class EliminateableFactorGraph<DiscreteFactorGraph>;
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteFactorGraph::DiscreteFactorGraph() { bool DiscreteFactorGraph::equals(const This& fg, double tol) const
} {
return Base::equals(fg, tol);
/* ************************************************************************* */
DiscreteFactorGraph::DiscreteFactorGraph(
const BayesNetOrdered<DiscreteConditional>& bayesNet) :
FactorGraphOrdered<DiscreteFactor>(bayesNet) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -75,27 +75,34 @@ namespace gtsam {
} }
} }
/* ************************************************************************* */ // /* ************************************************************************* */
void DiscreteFactorGraph::permuteWithInverse( // void DiscreteFactorGraph::permuteWithInverse(
const Permutation& inversePermutation) { // const Permutation& inversePermutation) {
BOOST_FOREACH(const sharedFactor& factor, factors_) { // BOOST_FOREACH(const sharedFactor& factor, factors_) {
if(factor) // if(factor)
factor->permuteWithInverse(inversePermutation); // factor->permuteWithInverse(inversePermutation);
} // }
} // }
//
// /* ************************************************************************* */
// void DiscreteFactorGraph::reduceWithInverse(
// const internal::Reduction& inverseReduction) {
// BOOST_FOREACH(const sharedFactor& factor, factors_) {
// if(factor)
// factor->reduceWithInverse(inverseReduction);
// }
// }
/* ************************************************************************* */ /* ************************************************************************* */
void DiscreteFactorGraph::reduceWithInverse( DiscreteFactor::sharedValues DiscreteFactorGraph::optimize() const
const internal::Reduction& inverseReduction) { {
BOOST_FOREACH(const sharedFactor& factor, factors_) { gttic(DiscreteFactorGraph_optimize);
if(factor) return BaseEliminateable::eliminateSequential()->optimize();
factor->reduceWithInverse(inverseReduction);
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const FactorGraphOrdered<DiscreteFactor>& factors, size_t num) { EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); gttic(product);
@ -107,18 +114,22 @@ namespace gtsam {
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
gttic(sum); gttic(sum);
DecisionTreeFactor::shared_ptr sum = product.sum(num); DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
gttoc(sum); gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
gttic(divide); gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
gttoc(divide); gttoc(divide);
return std::make_pair(cond, sum); return std::make_pair(cond, sum);
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace } // namespace

View File

@ -18,33 +18,85 @@
#pragma once #pragma once
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/inference/FactorGraphOrdered.h>
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
namespace gtsam { namespace gtsam {
class DiscreteFactorGraph: public FactorGraphOrdered<DiscreteFactor> { // Forward declarations
class DiscreteFactorGraph;
class DiscreteFactor;
class DiscreteConditional;
class DiscreteBayesNet;
class DiscreteEliminationTree;
class DiscreteBayesTree;
class DiscreteJunctionTree;
/** Main elimination function for DiscreteFactorGraph */
std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
/* ************************************************************************* */
template<> struct EliminationTraits<DiscreteFactorGraph>
{
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
typedef DiscreteFactorGraph FactorGraphType; ///< Type of the factor graph (e.g. DiscreteFactorGraph)
typedef DiscreteConditional ConditionalType; ///< Type of conditionals from elimination
typedef DiscreteBayesNet BayesNetType; ///< Type of Bayes net from sequential elimination
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys); }
};
/* ************************************************************************* */
/**
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
*/
class DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public: public:
typedef DiscreteFactorGraph This; ///< Typedef to this class
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
/** A map from keys to values */ /** A map from keys to values */
typedef std::vector<Index> Indices; typedef std::vector<Index> Indices;
typedef Assignment<Index> Values; typedef Assignment<Index> Values;
typedef boost::shared_ptr<Values> sharedValues; typedef boost::shared_ptr<Values> sharedValues;
/** Construct empty factor graph */ /** Default constructor */
GTSAM_EXPORT DiscreteFactorGraph(); DiscreteFactorGraph() {}
/** Constructor from a factor graph of GaussianFactor or a derived type */ /** Construct from iterator over factors */
template<typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDFACTOR> template<class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraphOrdered<DERIVEDFACTOR>& fg) { DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
push_back(fg);
}
/** construct from a BayesNet */ /// @name Testable
GTSAM_EXPORT DiscreteFactorGraph(const BayesNetOrdered<DiscreteConditional>& bayesNet); /// @{
bool equals(const This& fg, double tol = 1e-9) const;
/// @}
template<class SOURCE> template<class SOURCE>
void add(const DiscreteKey& j, SOURCE table) { void add(const DiscreteKey& j, SOURCE table) {
@ -79,18 +131,20 @@ public:
/// print /// print
GTSAM_EXPORT void print(const std::string& s = "DiscreteFactorGraph", GTSAM_EXPORT void print(const std::string& s = "DiscreteFactorGraph",
const IndexFormatter& formatter =DefaultIndexFormatter) const; const IndexFormatter& formatter =DefaultIndexFormatter) const;
/** Solve the factor graph by performing variable elimination in COLAMD order using
* the dense elimination function specified in \c function,
* followed by back-substitution resulting from elimination. Is equivalent
* to calling graph.eliminateSequential()->optimize(). */
DiscreteFactor::sharedValues optimize() const;
/** Permute the variables in the factors */ // /** Permute the variables in the factors */
GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); // GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
//
/** Apply a reduction, which is a remapping of variable indices. */ // /** Apply a reduction, which is a remapping of variable indices. */
GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); // GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
}; };
// DiscreteFactorGraph // DiscreteFactorGraph
/** Main elimination function for DiscreteFactorGraph */
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const FactorGraphOrdered<DiscreteFactor>& factors,
size_t nrFrontals = 1);
} // namespace gtsam } // namespace gtsam

View File

@ -0,0 +1,34 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteJunctionTree.cpp
* @date Mar 29, 2013
* @author Frank Dellaert
* @author Richard Roberts
*/
#include <gtsam/inference/JunctionTree-inst.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
namespace gtsam {
// Instantiate base classes
template class ClusterTree<DiscreteBayesTree, DiscreteFactorGraph>;
template class JunctionTree<DiscreteBayesTree, DiscreteFactorGraph>;
/* ************************************************************************* */
DiscreteJunctionTree::DiscreteJunctionTree(
const DiscreteEliminationTree& eliminationTree) :
Base(eliminationTree) {}
}

View File

@ -0,0 +1,66 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteJunctionTree.h
* @date Mar 29, 2013
* @author Frank Dellaert
* @author Richard Roberts
*/
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/inference/JunctionTree.h>
namespace gtsam {
// Forward declarations
class DiscreteEliminationTree;
/**
* A ClusterTree, i.e., a set of variable clusters with factors, arranged in a tree, with
* the additional property that it represents the clique tree associated with a Bayes net.
*
* In GTSAM a junction tree is an intermediate data structure in multifrontal
* variable elimination. Each node is a cluster of factors, along with a
* clique of variables that are eliminated all at once. In detail, every node k represents
* a clique (maximal fully connected subset) of an associated chordal graph, such as a
* chordal Bayes net resulting from elimination.
*
* The difference with the BayesTree is that a JunctionTree stores factors, whereas a
* BayesTree stores conditionals, that are the product of eliminating the factors in the
* corresponding JunctionTree cliques.
*
* The tree structure and elimination method are exactly analagous to the EliminationTree,
* except that in the JunctionTree, at each node multiple variables are eliminated at a time.
*
* \addtogroup Multifrontal
* \nosubgrouping
*/
class GTSAM_EXPORT DiscreteJunctionTree :
public JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> {
public:
typedef JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> Base; ///< Base class
typedef DiscreteJunctionTree This; ///< This class
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
/**
* Build the elimination tree of a factor graph using pre-computed column structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is not
* precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
};
}

View File

@ -21,7 +21,8 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/GenericMultifrontalSolver.h> #include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/base/Vector.h>
namespace gtsam { namespace gtsam {
@ -32,7 +33,7 @@ namespace gtsam {
protected: protected:
BayesTreeOrdered<DiscreteConditional> bayesTree_; DiscreteBayesTree::shared_ptr bayesTree_;
public: public:
@ -40,16 +41,14 @@ namespace gtsam {
* @param graph The factor graph defining the full joint density on all variables. * @param graph The factor graph defining the full joint density on all variables.
*/ */
DiscreteMarginals(const DiscreteFactorGraph& graph) { DiscreteMarginals(const DiscreteFactorGraph& graph) {
typedef JunctionTreeOrdered<DiscreteFactorGraph> DiscreteJT; bayesTree_ = graph.eliminateMultifrontal();
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
bayesTree_ = *solver.eliminate(&EliminateDiscrete);
} }
/** Compute the marginal of a single variable */ /** Compute the marginal of a single variable */
DiscreteFactor::shared_ptr operator()(Index variable) const { DiscreteFactor::shared_ptr operator()(Index variable) const {
// Compute marginal // Compute marginal
DiscreteFactor::shared_ptr marginalFactor; DiscreteFactor::shared_ptr marginalFactor;
marginalFactor = bayesTree_.marginalFactor(variable, &EliminateDiscrete); marginalFactor = bayesTree_->marginalFactor(variable, &EliminateDiscrete);
return marginalFactor; return marginalFactor;
} }
@ -60,7 +59,7 @@ namespace gtsam {
Vector marginalProbabilities(const DiscreteKey& key) const { Vector marginalProbabilities(const DiscreteKey& key) const {
// Compute marginal // Compute marginal
DiscreteFactor::shared_ptr marginalFactor; DiscreteFactor::shared_ptr marginalFactor;
marginalFactor = bayesTree_.marginalFactor(key.first, &EliminateDiscrete); marginalFactor = bayesTree_->marginalFactor(key.first, &EliminateDiscrete);
//Create result //Create result
Vector vResult(key.second); Vector vResult(key.second);

View File

@ -1,75 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteSequentialSolver.cpp
* @date Feb 16, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
//#define ENABLE_TIMING
#include <gtsam/discrete/DiscreteSequentialSolver.h>
#include <gtsam/inference/GenericSequentialSolver-inl.h>
#include <gtsam/base/timing.h>
namespace gtsam {
template class GenericSequentialSolver<DiscreteFactor> ;
/* ************************************************************************* */
DiscreteFactor::sharedValues DiscreteSequentialSolver::optimize() const {
static const bool debug = false;
if (debug) this->factors_->print("DiscreteSequentialSolver, eliminating ");
if (debug) this->eliminationTree_->print(
"DiscreteSequentialSolver, elimination tree ");
// Eliminate using the elimination tree
gttic(eliminate);
DiscreteBayesNet::shared_ptr bayesNet = eliminate();
gttoc(eliminate);
if (debug) bayesNet->print("DiscreteSequentialSolver, Bayes net ");
// Allocate the solution vector if it is not already allocated
// Back-substitute
gttic(optimize);
DiscreteFactor::sharedValues solution = gtsam::optimize(*bayesNet);
gttoc(optimize);
if (debug) solution->print("DiscreteSequentialSolver, solution ");
return solution;
}
/* ************************************************************************* */
Vector DiscreteSequentialSolver::marginalProbabilities(
const DiscreteKey& key) const {
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor;
marginalFactor = Base::marginalFactor(key.first, &EliminateDiscrete);
//Create result
Vector vResult(key.second);
for (size_t state = 0; state < key.second; ++state) {
DiscreteFactor::Values values;
values[key.first] = state;
vResult(state) = (*marginalFactor)(values);
}
return vResult;
}
/* ************************************************************************* */
}

View File

@ -1,113 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteSequentialSolver.h
* @date Feb 16, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/GenericSequentialSolver.h>
#include <boost/shared_ptr.hpp>
namespace gtsam {
// The base class provides all of the needed functionality
class DiscreteSequentialSolver: public GenericSequentialSolver<DiscreteFactor> {
protected:
typedef GenericSequentialSolver<DiscreteFactor> Base;
typedef boost::shared_ptr<const DiscreteSequentialSolver> shared_ptr;
public:
/**
* The problem we are trying to solve (SUM or MPE).
*/
typedef enum {
BEL, // Belief updating (or conditional updating)
MPE, // Most-Probable-Explanation
MAP
// Maximum A Posteriori hypothesis
} ProblemType;
/**
* Construct the solver for a factor graph. This builds the elimination
* tree, which already does some of the work of elimination.
*/
DiscreteSequentialSolver(const FactorGraphOrdered<DiscreteFactor>& factorGraph) :
Base(factorGraph) {
}
/**
* Construct the solver with a shared pointer to a factor graph and to a
* VariableIndex. The solver will store these pointers, so this constructor
* is the fastest.
*/
DiscreteSequentialSolver(
const FactorGraphOrdered<DiscreteFactor>::shared_ptr& factorGraph,
const VariableIndexOrdered::shared_ptr& variableIndex) :
Base(factorGraph, variableIndex) {
}
const EliminationTreeOrdered<DiscreteFactor>& eliminationTree() const {
return *eliminationTree_;
}
/**
* Eliminate the factor graph sequentially. Uses a column elimination tree
* to recursively eliminate.
*/
BayesNetOrdered<DiscreteConditional>::shared_ptr eliminate() const {
return Base::eliminate(&EliminateDiscrete);
}
/**
* Compute the marginal joint over a set of variables, by integrating out
* all of the other variables. This function returns the result as a factor
* graph.
*/
DiscreteFactorGraph::shared_ptr jointFactorGraph(
const std::vector<Index>& js) const {
DiscreteFactorGraph::shared_ptr results(new DiscreteFactorGraph(
*Base::jointFactorGraph(js, &EliminateDiscrete)));
return results;
}
/**
* Compute the marginal density over a variable, by integrating out
* all of the other variables. This function returns the result as a factor.
*/
DiscreteFactor::shared_ptr marginalFactor(Index j) const {
return Base::marginalFactor(j, &EliminateDiscrete);
}
/**
* Compute the marginal density over a variable, by integrating out
* all of the other variables. This function returns the result as a
* Vector of the probability values.
*/
GTSAM_EXPORT Vector marginalProbabilities(const DiscreteKey& key) const;
/**
* Compute the MPE solution of the DiscreteFactorGraph. This
* eliminates to create a BayesNet and then back-substitutes this BayesNet to
* obtain the solution.
*/
GTSAM_EXPORT DiscreteFactor::sharedValues optimize() const;
};
} // gtsam

View File

@ -59,39 +59,39 @@ namespace gtsam {
cout << endl; cout << endl;
ADT::print(" "); ADT::print(" ");
} }
//
/* ************************************************************************* */ // /* ************************************************************************* */
template<class P> // template<class P>
void Potentials::remapIndices(const P& remapping) { // void Potentials::remapIndices(const P& remapping) {
// Permute the _cardinalities (TODO: Inefficient Consider Improving) // // Permute the _cardinalities (TODO: Inefficient Consider Improving)
DiscreteKeys keys; // DiscreteKeys keys;
map<Index, Index> ordering; // map<Index, Index> ordering;
//
// Get the original keys from cardinalities_ // // Get the original keys from cardinalities_
BOOST_FOREACH(const DiscreteKey& key, cardinalities_) // BOOST_FOREACH(const DiscreteKey& key, cardinalities_)
keys & key; // keys & key;
//
// Perform Permutation // // Perform Permutation
BOOST_FOREACH(DiscreteKey& key, keys) { // BOOST_FOREACH(DiscreteKey& key, keys) {
ordering[key.first] = remapping[key.first]; // ordering[key.first] = remapping[key.first];
key.first = ordering[key.first]; // key.first = ordering[key.first];
} // }
//
// Change *this // // Change *this
AlgebraicDecisionTree<Index> permuted((*this), ordering); // AlgebraicDecisionTree<Index> permuted((*this), ordering);
*this = permuted; // *this = permuted;
cardinalities_ = keys.cardinalities(); // cardinalities_ = keys.cardinalities();
} // }
//
/* ************************************************************************* */ // /* ************************************************************************* */
void Potentials::permuteWithInverse(const Permutation& inversePermutation) { // void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
remapIndices(inversePermutation); // remapIndices(inversePermutation);
} // }
//
/* ************************************************************************* */ // /* ************************************************************************* */
void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) { // void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
remapIndices(inverseReduction); // remapIndices(inverseReduction);
} // }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -19,7 +19,6 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/PermutationOrdered.h>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <set> #include <set>
@ -48,9 +47,9 @@ namespace gtsam {
// Safe division for probabilities // Safe division for probabilities
GTSAM_EXPORT static double safe_div(const double& a, const double& b); GTSAM_EXPORT static double safe_div(const double& a, const double& b);
// Apply either a permutation or a reduction // // Apply either a permutation or a reduction
template<class P> // template<class P>
void remapIndices(const P& remapping); // void remapIndices(const P& remapping);
public: public:
@ -73,19 +72,19 @@ namespace gtsam {
size_t cardinality(Index j) const { return cardinalities_.at(j);} size_t cardinality(Index j) const { return cardinalities_.at(j);}
/** // /**
* @brief Permutes the keys in Potentials // * @brief Permutes the keys in Potentials
* // *
* This permutes the Indices and performs necessary re-ordering of ADD. // * This permutes the Indices and performs necessary re-ordering of ADD.
* This is virtual so that derived types e.g. DecisionTreeFactor can // * This is virtual so that derived types e.g. DecisionTreeFactor can
* re-implement it. // * re-implement it.
*/ // */
GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation); // GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
//
/** // /**
* Apply a reduction, which is a remapping of variable indices. // * Apply a reduction, which is a remapping of variable indices.
*/ // */
GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction); // GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
}; // Potentials }; // Potentials

View File

@ -12,7 +12,7 @@
/** /**
* @file Signature.cpp * @file Signature.cpp
* @brief signatures for conditional densities * @brief signatures for conditional densities
* @author Frank dellaert * @author Frank Dellaert
* @date Feb 27, 2011 * @date Feb 27, 2011
*/ */

View File

@ -12,7 +12,7 @@
/** /**
* @file Signature.h * @file Signature.h
* @brief signatures for conditional densities * @brief signatures for conditional densities
* @author Frank dellaert * @author Frank Dellaert
* @date Feb 27, 2011 * @date Feb 27, 2011
*/ */

View File

@ -17,13 +17,15 @@
*/ */
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteSequentialSolver.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h> #include <gtsam/base/debug.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
#include <boost/assign/list_inserter.hpp>
using namespace boost::assign; using namespace boost::assign;
#include <iostream> #include <iostream>
@ -40,38 +42,40 @@ TEST(DiscreteBayesNet, Asia)
DiscreteKey A(0,2), S(4,2), T(3,2), L(6,2), B(7,2), E(5,2), X(2,2), D(1,2); DiscreteKey A(0,2), S(4,2), T(3,2), L(6,2), B(7,2), E(5,2), X(2,2), D(1,2);
// TODO: make a version that doesn't use the parser // TODO: make a version that doesn't use the parser
add_front(asia, A % "99/1"); asia.add(A % "99/1");
add_front(asia, S % "50/50"); asia.add(S % "50/50");
add_front(asia, T | A = "99/1 95/5"); asia.add(T | A = "99/1 95/5");
add_front(asia, L | S = "99/1 90/10"); asia.add(L | S = "99/1 90/10");
add_front(asia, B | S = "70/30 40/60"); asia.add(B | S = "70/30 40/60");
add_front(asia, (E | T, L) = "F T T T"); asia.add((E | T, L) = "F T T T");
add_front(asia, X | E = "95/5 2/98"); asia.add(X | E = "95/5 2/98");
// next lines are same as add_front(asia, (D | E, B) = "9/1 2/8 3/7 1/9"); // next lines are same as asia.add((D | E, B) = "9/1 2/8 3/7 1/9");
DiscreteConditional::shared_ptr actual = DiscreteConditional::shared_ptr actual =
boost::make_shared<DiscreteConditional>((D | E, B) = "9/1 2/8 3/7 1/9"); boost::make_shared<DiscreteConditional>((D | E, B) = "9/1 2/8 3/7 1/9");
asia.push_front(actual); asia.push_back(actual);
// GTSAM_PRINT(asia); // GTSAM_PRINT(asia);
// Convert to factor graph // Convert to factor graph
DiscreteFactorGraph fg(asia); DiscreteFactorGraph fg(asia);
// GTSAM_PRINT(fg); // GTSAM_PRINT(fg);
LONGS_EQUAL(3,fg.front()->size()); LONGS_EQUAL(3,fg.back()->size());
Potentials::ADT expected(B & D & E, "0.9 0.3 0.1 0.7 0.2 0.1 0.8 0.9"); Potentials::ADT expected(B & D & E, "0.9 0.3 0.1 0.7 0.2 0.1 0.8 0.9");
CHECK(assert_equal(expected,(Potentials::ADT)*actual)); CHECK(assert_equal(expected,(Potentials::ADT)*actual));
// Create solver and eliminate // Create solver and eliminate
DiscreteSequentialSolver solver(fg); Ordering ordering;
DiscreteBayesNet::shared_ptr chordal = solver.eliminate(); ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7);
// GTSAM_PRINT(*chordal); ordering.print("ordering: ");
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// GTSAM_PRINT(*chordal);
DiscreteConditional expected2(B % "11/9"); DiscreteConditional expected2(B % "11/9");
CHECK(assert_equal(expected2,*chordal->back())); CHECK(assert_equal(expected2,*chordal->back()));
// solve // solve
DiscreteFactor::sharedValues actualMPE = optimize(*chordal); DiscreteFactor::sharedValues actualMPE = chordal->optimize();
DiscreteFactor::Values expectedMPE; DiscreteFactor::Values expectedMPE;
insert(expectedMPE)(A.first, 0)(D.first, 0)(X.first, 0)(T.first, 0)(S.first, insert(expectedMPE)(A.first, 0)(D.first, 0)(X.first, 0)(T.first, 0)(S.first,
0)(E.first, 0)(L.first, 0)(B.first, 0); 0)(E.first, 0)(L.first, 0)(B.first, 0);
@ -83,10 +87,10 @@ TEST(DiscreteBayesNet, Asia)
// fg.product().dot("fg"); // fg.product().dot("fg");
// solve again, now with evidence // solve again, now with evidence
DiscreteSequentialSolver solver2(fg); DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential();
DiscreteBayesNet::shared_ptr chordal2 = solver2.eliminate(); GTSAM_PRINT(*chordal2);
// GTSAM_PRINT(*chordal2); DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
DiscreteFactor::sharedValues actualMPE2 = optimize(*chordal2); cout << "optimized!" << endl;
DiscreteFactor::Values expectedMPE2; DiscreteFactor::Values expectedMPE2;
insert(expectedMPE2)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(S.first, insert(expectedMPE2)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(S.first,
1)(E.first, 0)(L.first, 0)(B.first, 1); 1)(E.first, 0)(L.first, 0)(B.first, 1);
@ -97,8 +101,8 @@ TEST(DiscreteBayesNet, Asia)
SETDEBUG("DiscreteConditional::sample", false); SETDEBUG("DiscreteConditional::sample", false);
insert(expectedSample)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)( insert(expectedSample)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(
S.first, 1)(E.first, 0)(L.first, 0)(B.first, 1); S.first, 1)(E.first, 0)(L.first, 0)(B.first, 1);
DiscreteFactor::sharedValues actualSample = sample(*chordal2); DiscreteFactor::sharedValues actualSample = chordal->sample();
EXPECT(assert_equal(expectedSample, *actualSample)); // EXPECT(assert_equal(expectedSample, *actualSample));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -114,12 +118,12 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar)
// add(bn, D | E = "blah"); // add(bn, D | E = "blah");
// try logic // try logic
add(bn, (E | T, L) = "OR"); bn.add((E | T, L) = "OR");
add(bn, (E | T, L) = "AND"); bn.add((E | T, L) = "AND");
// // try multivalued // // try multivalued
add(bn, C % "1/1/2"); bn.add(C % "1/1/2");
add(bn, C | S = "1/1/2 5/2/3"); bn.add(C | S = "1/1/2 5/2/3");
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -1,261 +1,261 @@
/* ---------------------------------------------------------------------------- ///* ----------------------------------------------------------------------------
//
* GTSAM Copyright 2010, Georgia Tech Research Corporation, // * GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415 // * Atlanta, Georgia 30332-0415
* All Rights Reserved // * All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) // * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
//
* See LICENSE for the license information // * See LICENSE for the license information
//
* -------------------------------------------------------------------------- */ // * -------------------------------------------------------------------------- */
//
/* ///*
* @file testDiscreteBayesTree.cpp // * @file testDiscreteBayesTree.cpp
* @date sept 15, 2012 // * @date sept 15, 2012
* @author Frank Dellaert // * @author Frank Dellaert
*/ // */
//
#include <gtsam/discrete/DiscreteBayesNet.h> //#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> //#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/inference/BayesTreeOrdered.h> //#include <gtsam/discrete/DiscreteFactorGraph.h>
//
#include <boost/assign/std/vector.hpp> //#include <boost/assign/std/vector.hpp>
using namespace boost::assign; //using namespace boost::assign;
//
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
//
using namespace std; //using namespace std;
using namespace gtsam; //using namespace gtsam;
//
static bool debug = false; //static bool debug = false;
//
/** ///**
* Custom clique class to debug shortcuts // * Custom clique class to debug shortcuts
*/ // */
class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> { ////class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> {
////
protected: ////protected:
////
public: ////public:
////
typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base; //// typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base;
typedef boost::shared_ptr<Clique> shared_ptr; //// typedef boost::shared_ptr<Clique> shared_ptr;
////
// Constructors //// // Constructors
Clique() { //// Clique() {
} //// }
Clique(const DiscreteConditional::shared_ptr& conditional) : //// Clique(const DiscreteConditional::shared_ptr& conditional) :
Base(conditional) { //// Base(conditional) {
} //// }
Clique( //// Clique(
const std::pair<DiscreteConditional::shared_ptr, //// const std::pair<DiscreteConditional::shared_ptr,
DiscreteConditional::FactorType::shared_ptr>& result) : //// DiscreteConditional::FactorType::shared_ptr>& result) :
Base(result) { //// Base(result) {
} //// }
////
/// print index signature only //// /// print index signature only
void printSignature(const std::string& s = "Clique: ", //// void printSignature(const std::string& s = "Clique: ",
const IndexFormatter& indexFormatter = DefaultIndexFormatter) const { //// const IndexFormatter& indexFormatter = DefaultIndexFormatter) const {
((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter); //// ((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter);
} //// }
////
/// evaluate value of sub-tree //// /// evaluate value of sub-tree
double evaluate(const DiscreteConditional::Values & values) { //// double evaluate(const DiscreteConditional::Values & values) {
double result = (*(this->conditional_))(values); //// double result = (*(this->conditional_))(values);
// evaluate all children and multiply into result //// // evaluate all children and multiply into result
BOOST_FOREACH(boost::shared_ptr<Clique> c, children_) //// BOOST_FOREACH(boost::shared_ptr<Clique> c, children_)
result *= c->evaluate(values); //// result *= c->evaluate(values);
return result; //// return result;
} //// }
////
}; ////};
//
typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree; ////typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
////
/* ************************************************************************* */ /////* ************************************************************************* */
double evaluate(const DiscreteBayesTree& tree, ////double evaluate(const DiscreteBayesTree& tree,
const DiscreteConditional::Values & values) { //// const DiscreteConditional::Values & values) {
return tree.root()->evaluate(values); //// return tree.root()->evaluate(values);
} ////}
//
/* ************************************************************************* */ ///* ************************************************************************* */
//
TEST_UNSAFE( DiscreteBayesTree, thinTree ) { //TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
//
const int nrNodes = 15; // const int nrNodes = 15;
const size_t nrStates = 2; // const size_t nrStates = 2;
//
// define variables // // define variables
vector<DiscreteKey> key; // vector<DiscreteKey> key;
for (int i = 0; i < nrNodes; i++) { // for (int i = 0; i < nrNodes; i++) {
DiscreteKey key_i(i, nrStates); // DiscreteKey key_i(i, nrStates);
key.push_back(key_i); // key.push_back(key_i);
} // }
//
// create a thin-tree Bayesnet, a la Jean-Guillaume // // create a thin-tree Bayesnet, a la Jean-Guillaume
DiscreteBayesNet bayesNet; // DiscreteBayesNet bayesNet;
add_front(bayesNet, key[14] % "1/3"); // bayesNet.add(key[14] % "1/3");
//
add_front(bayesNet, key[13] | key[14] = "1/3 3/1"); // bayesNet.add(key[13] | key[14] = "1/3 3/1");
add_front(bayesNet, key[12] | key[14] = "3/1 3/1"); // bayesNet.add(key[12] | key[14] = "3/1 3/1");
//
add_front(bayesNet, (key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1"); // bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
add_front(bayesNet, (key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1"); // bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
add_front(bayesNet, (key[9] | key[12], key[14]) = "4/1 2/3 F 1/4"); // bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
add_front(bayesNet, (key[8] | key[12], key[14]) = "T 1/4 3/2 4/1"); // bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
//
add_front(bayesNet, (key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1"); // bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
add_front(bayesNet, (key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1"); // bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
add_front(bayesNet, (key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4"); // bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
add_front(bayesNet, (key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1"); // bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
//
add_front(bayesNet, (key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1"); // bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
add_front(bayesNet, (key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1"); // bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
add_front(bayesNet, (key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4"); // bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
add_front(bayesNet, (key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1"); // bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
//
if (debug) { //// if (debug) {
GTSAM_PRINT(bayesNet); //// GTSAM_PRINT(bayesNet);
bayesNet.saveGraph("/tmp/discreteBayesNet.dot"); //// bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
} //// }
//
// create a BayesTree out of a Bayes net // // create a BayesTree out of a Bayes net
DiscreteBayesTree bayesTree(bayesNet); // DiscreteBayesTree bayesTree(bayesNet);
if (debug) { // if (debug) {
GTSAM_PRINT(bayesTree); // GTSAM_PRINT(bayesTree);
bayesTree.saveGraph("/tmp/discreteBayesTree.dot"); // bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
} // }
//
// Check whether BN and BT give the same answer on all configurations // // Check whether BN and BT give the same answer on all configurations
// Also calculate all some marginals // // Also calculate all some marginals
Vector marginals = zero(15); // Vector marginals = zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0, // double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0, // joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
joint_4_11 = 0; // joint_4_11 = 0;
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct( // vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7] // key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]); // & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) { // for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values x = allPosbValues[i]; // DiscreteFactor::Values x = allPosbValues[i];
double expected = evaluate(bayesNet, x); // double expected = evaluate(bayesNet, x);
double actual = evaluate(bayesTree, x); // double actual = evaluate(bayesTree, x);
DOUBLES_EQUAL(expected, actual, 1e-9); // DOUBLES_EQUAL(expected, actual, 1e-9);
// collect marginals // // collect marginals
for (size_t i = 0; i < 15; i++) // for (size_t i = 0; i < 15; i++)
if (x[i]) // if (x[i])
marginals[i] += actual; // marginals[i] += actual;
// calculate shortcut 8 and 0 // // calculate shortcut 8 and 0
if (x[12] && x[14]) // if (x[12] && x[14])
joint_12_14 += actual; // joint_12_14 += actual;
if (x[9] && x[12] & x[14]) // if (x[9] && x[12] & x[14])
joint_9_12_14 += actual; // joint_9_12_14 += actual;
if (x[8] && x[12] & x[14]) // if (x[8] && x[12] & x[14])
joint_8_12_14 += actual; // joint_8_12_14 += actual;
if (x[8] && x[12]) // if (x[8] && x[12])
joint_8_12 += actual; // joint_8_12 += actual;
if (x[8] && x[2]) // if (x[8] && x[2])
joint82 += actual; // joint82 += actual;
if (x[1] && x[2]) // if (x[1] && x[2])
joint12 += actual; // joint12 += actual;
if (x[2] && x[4]) // if (x[2] && x[4])
joint24 += actual; // joint24 += actual;
if (x[4] && x[5]) // if (x[4] && x[5])
joint45 += actual; // joint45 += actual;
if (x[4] && x[6]) // if (x[4] && x[6])
joint46 += actual; // joint46 += actual;
if (x[4] && x[11]) // if (x[4] && x[11])
joint_4_11 += actual; // joint_4_11 += actual;
} // }
DiscreteFactor::Values all1 = allPosbValues.back(); // DiscreteFactor::Values all1 = allPosbValues.back();
//
Clique::shared_ptr R = bayesTree.root(); // Clique::shared_ptr R = bayesTree.root();
//
// check separator marginal P(S0) // // check separator marginal P(S0)
Clique::shared_ptr c = bayesTree[0]; // Clique::shared_ptr c = bayesTree[0];
DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R, // DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R,
EliminateDiscrete); // EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9); // EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
//
// check separator marginal P(S9), should be P(14) // // check separator marginal P(S9), should be P(14)
c = bayesTree[9]; // c = bayesTree[9];
DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R, // DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R,
EliminateDiscrete); // EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9); // EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
//
// check separator marginal of root, should be empty // // check separator marginal of root, should be empty
c = bayesTree[11]; // c = bayesTree[11];
DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R, // DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
EliminateDiscrete); // EliminateDiscrete);
EXPECT_LONGS_EQUAL(0, separatorMarginal11.size()); // EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
//
// check shortcut P(S9||R) to root // // check shortcut P(S9||R) to root
c = bayesTree[9]; // c = bayesTree[9];
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); // DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_LONGS_EQUAL(0, shortcut.size()); // EXPECT_LONGS_EQUAL(0, shortcut.size());
//
// check shortcut P(S8||R) to root // // check shortcut P(S8||R) to root
c = bayesTree[8]; // c = bayesTree[8];
shortcut = c->shortcut(R, EliminateDiscrete); // shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1), // EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
1e-9); // 1e-9);
//
// check shortcut P(S2||R) to root // // check shortcut P(S2||R) to root
c = bayesTree[2]; // c = bayesTree[2];
shortcut = c->shortcut(R, EliminateDiscrete); // shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1), // EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
1e-9); // 1e-9);
//
// check shortcut P(S0||R) to root // // check shortcut P(S0||R) to root
c = bayesTree[0]; // c = bayesTree[0];
shortcut = c->shortcut(R, EliminateDiscrete); // shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1), // EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
1e-9); // 1e-9);
//
// calculate all shortcuts to root // // calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = bayesTree.nodes(); // DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
BOOST_FOREACH(Clique::shared_ptr c, cliques) { // BOOST_FOREACH(Clique::shared_ptr c, cliques) {
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete); // DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
if (debug) { // if (debug) {
c->printSignature(); // c->printSignature();
shortcut.print("shortcut:"); // shortcut.print("shortcut:");
} // }
} // }
//
// Check all marginals // // Check all marginals
DiscreteFactor::shared_ptr marginalFactor; // DiscreteFactor::shared_ptr marginalFactor;
for (size_t i = 0; i < 15; i++) { // for (size_t i = 0; i < 15; i++) {
marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete); // marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1); // double actual = (*marginalFactor)(all1);
EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9); // EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
} // }
//
DiscreteBayesNet::shared_ptr actualJoint; // DiscreteBayesNet::shared_ptr actualJoint;
//
// Check joint P(8,2) TODO: not disjoint ! // // Check joint P(8,2) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete); //// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9); //// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
//
// Check joint P(1,2) TODO: not disjoint ! // // Check joint P(1,2) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete); //// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9); //// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
//
// Check joint P(2,4) // // Check joint P(2,4)
actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete); // actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9); // EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
//
// Check joint P(4,5) TODO: not disjoint ! // // Check joint P(4,5) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete); //// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9); //// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
//
// Check joint P(4,6) TODO: not disjoint ! // // Check joint P(4,6) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete); //// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9); //// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
//
// Check joint P(4,11) // // Check joint P(4,11)
actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete); // actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9); // EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
//
} //}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {

View File

@ -17,8 +17,8 @@
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteSequentialSolver.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/inference/GenericMultifrontalSolver.h> #include <gtsam/discrete/DiscreteBayesTree.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
@ -58,9 +58,9 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
// result.first->print("BayesNet: "); // result.first->print("BayesNet: ");
// result.second->print("New factor: "); // result.second->print("New factor: ");
// //
DiscreteSequentialSolver solver(graph); Ordering ordering;
// solver.print("solver:"); ordering += Key(0),Key(1),Key(2),Key(3);
EliminationTreeOrdered<DiscreteFactor> eliminationTree = solver.eliminationTree(); DiscreteEliminationTree eliminationTree(graph, ordering);
// eliminationTree.print("Elimination tree: "); // eliminationTree.print("Elimination tree: ");
eliminationTree.eliminate(EliminateDiscrete); eliminationTree.eliminate(EliminateDiscrete);
// solver.optimize(); // solver.optimize();
@ -127,10 +127,11 @@ TEST( DiscreteFactorGraph, test)
// Test EliminateDiscrete // Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net // FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys;
frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional; DiscreteConditional::shared_ptr conditional;
DecisionTreeFactor::shared_ptr newFactor; DecisionTreeFactor::shared_ptr newFactor;
boost::tie(conditional, newFactor) =// boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
EliminateDiscrete(graph, 1);
// Check Bayes net // Check Bayes net
CHECK(conditional); CHECK(conditional);
@ -139,34 +140,35 @@ TEST( DiscreteFactorGraph, test)
// cout << signature << endl; // cout << signature << endl;
DiscreteConditional expectedConditional(signature); DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional)); EXPECT(assert_equal(expectedConditional, *conditional));
add(expected, signature); expected.add(signature);
// Check Factor // Check Factor
CHECK(newFactor); CHECK(newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor)); EXPECT(assert_equal(expectedFactor, *newFactor));
// Create solver
DiscreteSequentialSolver solver(graph);
// add conditionals to complete expected Bayes net // add conditionals to complete expected Bayes net
add(expected, B | A = "5/3 3/5"); expected.add(B | A = "5/3 3/5");
add(expected, A % "1/1"); expected.add(A % "1/1");
// GTSAM_PRINT(expected); // GTSAM_PRINT(expected);
// Test elimination tree // Test elimination tree
const EliminationTreeOrdered<DiscreteFactor>& etree = solver.eliminationTree(); Ordering ordering;
DiscreteBayesNet::shared_ptr actual = etree.eliminate(&EliminateDiscrete); ordering += Key(0), Key(1), Key(2);
DiscreteEliminationTree etree(graph, ordering);
DiscreteBayesNet::shared_ptr actual;
DiscreteFactorGraph::shared_ptr remainingGraph;
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual)); EXPECT(assert_equal(expected, *actual));
// Test solver // // Test solver
DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); // DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
EXPECT(assert_equal(expected, *actual2)); // EXPECT(assert_equal(expected, *actual2));
// Test optimization // Test optimization
DiscreteFactor::Values expectedValues; DiscreteFactor::Values expectedValues;
insert(expectedValues)(0, 0)(1, 0)(2, 0); insert(expectedValues)(0, 0)(1, 0)(2, 0);
DiscreteFactor::sharedValues actualValues = solver.optimize(); DiscreteFactor::sharedValues actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, *actualValues)); EXPECT(assert_equal(expectedValues, *actualValues));
} }
@ -183,8 +185,7 @@ TEST( DiscreteFactorGraph, testMPE)
// graph.product().print(); // graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print(); // DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteFactor::sharedValues actualMPE = DiscreteFactor::sharedValues actualMPE = graph.optimize();
DiscreteSequentialSolver(graph).optimize();
DiscreteFactor::Values expectedMPE; DiscreteFactor::Values expectedMPE;
insert(expectedMPE)(0, 0)(1, 1)(2, 1); insert(expectedMPE)(0, 0)(1, 1)(2, 1);
@ -213,18 +214,23 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
// Use the solver machinery. // Use the solver machinery.
DiscreteBayesNet::shared_ptr chordal = DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
DiscreteSequentialSolver(graph).eliminate(); DiscreteFactor::sharedValues actualMPE = chordal->optimize();
DiscreteFactor::sharedValues actualMPE = optimize(*chordal);
EXPECT(assert_equal(expectedMPE, *actualMPE)); EXPECT(assert_equal(expectedMPE, *actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back(); // DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9); // EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now // Let us create the Bayes tree here, just for fun, because we don't use it now
typedef JunctionTreeOrdered<DiscreteFactorGraph> JT; // typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph); // GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete); // BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
// bayesTree->print("Bayes Tree"); //// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2,bayesTree->size()); EXPECT_LONGS_EQUAL(2,bayesTree->size());
#ifdef OLD #ifdef OLD

View File

@ -18,7 +18,6 @@
*/ */
#include <gtsam/discrete/DiscreteMarginals.h> #include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/discrete/DiscreteSequentialSolver.h>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
using namespace boost::assign; using namespace boost::assign;
@ -119,11 +118,9 @@ TEST_UNSAFE( DiscreteMarginals, truss ) {
graph.add(key[0] & key[2] & key[4],"2 2 2 2 1 1 1 1"); graph.add(key[0] & key[2] & key[4],"2 2 2 2 1 1 1 1");
graph.add(key[1] & key[3] & key[4],"1 1 1 1 2 2 2 2"); graph.add(key[1] & key[3] & key[4],"1 1 1 1 2 2 2 2");
graph.add(key[2] & key[3] & key[4],"1 1 1 1 1 1 1 1"); graph.add(key[2] & key[3] & key[4],"1 1 1 1 1 1 1 1");
typedef JunctionTreeOrdered<DiscreteFactorGraph> JT; DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal();
GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
// bayesTree->print("Bayes Tree"); // bayesTree->print("Bayes Tree");
typedef BayesTreeClique<DiscreteConditional> Clique; typedef DiscreteBayesTreeClique Clique;
Clique expected0(boost::make_shared<DiscreteConditional>((key[0] | key[2], key[4]) = "2/1 2/1 2/1 2/1")); Clique expected0(boost::make_shared<DiscreteConditional>((key[0] | key[2], key[4]) = "2/1 2/1 2/1 2/1"));
Clique::shared_ptr actual0 = (*bayesTree)[0]; Clique::shared_ptr actual0 = (*bayesTree)[0];
@ -180,16 +177,16 @@ TEST_UNSAFE( DiscreteMarginals, truss2 ) {
} }
// Check all marginals given by a sequential solver and Marginals // Check all marginals given by a sequential solver and Marginals
DiscreteSequentialSolver solver(graph); // DiscreteSequentialSolver solver(graph);
DiscreteMarginals marginals(graph); DiscreteMarginals marginals(graph);
for (size_t j=0;j<5;j++) { for (size_t j=0;j<5;j++) {
double sum = T[j]+F[j]; double sum = T[j]+F[j];
T[j]/=sum; T[j]/=sum;
F[j]/=sum; F[j]/=sum;
// solver // // solver
Vector actualV = solver.marginalProbabilities(key[j]); // Vector actualV = solver.marginalProbabilities(key[j]);
EXPECT(assert_equal(Vector_(2,F[j],T[j]), actualV)); // EXPECT(assert_equal(Vector_(2,F[j],T[j]), actualV));
// Marginals // Marginals
vector<double> table; vector<double> table;