revive discrete, except testDiscreteBayesNet::sample, and the whole testDiscreteBayesTree, currently disabled. It's still a mess!!!

release/4.3a0
Duy-Nguyen Ta 2013-10-10 21:59:49 +00:00
parent 1a8a670870
commit 9aede467d2
30 changed files with 1347 additions and 1042 deletions

View File

@ -6,7 +6,7 @@ set (gtsam_subdirs
geometry
inference
symbolic
#discrete
discrete
linear
nonlinear
slam
@ -71,7 +71,7 @@ set(gtsam_srcs
${geometry_srcs}
${inference_srcs}
${symbolic_srcs}
#${discrete_srcs}
${discrete_srcs}
${linear_srcs}
${nonlinear_srcs}
${slam_srcs}

View File

@ -643,6 +643,7 @@ namespace gtsam {
// where there is no more branch on "label": only the subtree under that
// branch point corresponding to the value "index" is left instead.
// The function below get all these smaller trees and "ops" them together.
// This implements marginalization in Darwiche09book, pg 330
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::combine(const L& label,
size_t cardinality, const Binary& op) const {

View File

@ -177,7 +177,8 @@ namespace gtsam {
/** apply binary operation "op" to f and g */
DecisionTree apply(const DecisionTree& g, const Binary& op) const;
/** create a new function where value(label)==index */
/** create a new function where value(label)==index
* It's like "restrict" in Darwiche09book pg329, 330? */
DecisionTree choose(const L& label, size_t index) const {
NodePtr newRoot = root_->choose(label, index);
return DecisionTree(newRoot);

View File

@ -44,21 +44,25 @@ namespace gtsam {
}
/* ************************************************************************* */
bool DecisionTreeFactor::equals(const This& other, double tol) const {
return IndexFactorOrdered::equals(other, tol) && Potentials::equals(other, tol);
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
if(!dynamic_cast<const DecisionTreeFactor*>(&other))
return false;
else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return Potentials::equals(f, tol);
}
}
/* ************************************************************************* */
void DecisionTreeFactor::print(const string& s,
const IndexFormatter& formatter) const {
cout << s;
IndexFactorOrdered::print("IndexFactor:",formatter);
Potentials::print("Potentials:",formatter);
}
/* ************************************************************************* */
DecisionTreeFactor DecisionTreeFactor::apply //
(const DecisionTreeFactor& f, ADT::Binary op) const {
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const {
map<Index,size_t> cs; // new cardinalities
// make unique key-cardinality map
BOOST_FOREACH(Index j, keys()) cs[j] = cardinality(j);
@ -74,8 +78,8 @@ namespace gtsam {
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine //
(size_t nrFrontals, ADT::Binary op) const {
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
ADT::Binary op) const {
if (nrFrontals > size()) throw invalid_argument(
(boost::format(
@ -99,5 +103,36 @@ namespace gtsam {
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
ADT::Binary op) const {
if (frontalKeys.size() > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% frontalKeys.size() % size()).str());
// sum over nrFrontals keys
size_t i;
ADT result(*this);
for (i = 0; i < frontalKeys.size(); i++) {
Index j = frontalKeys[i];
result = result.combine(j, cardinality(j), op);
}
// create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) {
Index j = keys()[i];
// TODO: inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
continue;
dkeys.push_back(DiscreteKey(j,cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -20,6 +20,7 @@
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/Potentials.h>
#include <gtsam/inference/Ordering.h>
#include <boost/foreach.hpp>
#include <boost/shared_ptr.hpp>
@ -41,7 +42,7 @@ namespace gtsam {
// typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This;
typedef DiscreteConditional ConditionalType;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
public:
@ -69,7 +70,7 @@ namespace gtsam {
/// @{
/// equality
bool equals(const DecisionTreeFactor& other, double tol = 1e-9) const;
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
// print
virtual void print(const std::string& s = "DecisionTreeFactor:\n",
@ -104,6 +105,11 @@ namespace gtsam {
return combine(nrFrontals, ADT::Ring::add);
}
/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
return combine(keys, ADT::Ring::add);
}
/// Create new factor by maximizing over all values with the same separator values
shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max);
@ -129,27 +135,39 @@ namespace gtsam {
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
/**
* @brief Permutes the keys in Potentials and DiscreteFactor
*
* This re-implements the permuteWithInverse() in both Potentials
* and DiscreteFactor by doing both of them together.
* Combine frontal variables in an Ordering using binary operator "op"
* @param nrFrontals nr. of frontal to combine variables in this factor
* @param op a binary operator that operates on AlgebraicDecisionDiagram potentials
* @return shared pointer to newly created DecisionTreeFactor
*/
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
void permuteWithInverse(const Permutation& inversePermutation){
DiscreteFactor::permuteWithInverse(inversePermutation);
Potentials::permuteWithInverse(inversePermutation);
}
/**
* Apply a reduction, which is a remapping of variable indices.
*/
virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
DiscreteFactor::reduceWithInverse(inverseReduction);
Potentials::reduceWithInverse(inverseReduction);
}
/** Test whether the factor is empty */
virtual bool empty() const { return size() == 0; }
// /**
// * @brief Permutes the keys in Potentials and DiscreteFactor
// *
// * This re-implements the permuteWithInverse() in both Potentials
// * and DiscreteFactor by doing both of them together.
// */
//
// void permuteWithInverse(const Permutation& inversePermutation){
// DiscreteFactor::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
// }
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
// DiscreteFactor::reduceWithInverse(inverseReduction);
// Potentials::reduceWithInverse(inverseReduction);
// }
/// @}
};
};
// DecisionTreeFactor
}// namespace gtsam

View File

@ -18,47 +18,53 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/BayesNetOrdered-inl.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <boost/make_shared.hpp>
namespace gtsam {
// Explicitly instantiate so we don't have to include everywhere
template class BayesNetOrdered<DiscreteConditional> ;
// Instantiate base class
template class FactorGraph<DiscreteConditional>;
/* ************************************************************************* */
void add_front(DiscreteBayesNet& bayesNet, const Signature& s) {
bayesNet.push_front(boost::make_shared<DiscreteConditional>(s));
bool DiscreteBayesNet::equals(const This& bn, double tol) const
{
return Base::equals(bn, tol);
}
/* ************************************************************************* */
void add(DiscreteBayesNet& bayesNet, const Signature& s) {
bayesNet.push_back(boost::make_shared<DiscreteConditional>(s));
// void DiscreteBayesNet::add_front(const Signature& s) {
// push_front(boost::make_shared<DiscreteConditional>(s));
// }
/* ************************************************************************* */
void DiscreteBayesNet::add(const Signature& s) {
push_back(boost::make_shared<DiscreteConditional>(s));
}
/* ************************************************************************* */
double evaluate(const DiscreteBayesNet& bn, const DiscreteConditional::Values & values) {
double DiscreteBayesNet::evaluate(const DiscreteConditional::Values & values) const {
// evaluate all conditionals and multiply
double result = 1.0;
BOOST_FOREACH(DiscreteConditional::shared_ptr conditional, bn)
BOOST_FOREACH(DiscreteConditional::shared_ptr conditional, *this)
result *= (*conditional)(values);
return result;
}
/* ************************************************************************* */
DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn) {
DiscreteFactor::sharedValues DiscreteBayesNet::optimize() const {
// solve each node in turn in topological sort order (parents first)
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
BOOST_REVERSE_FOREACH (DiscreteConditional::shared_ptr conditional, bn)
BOOST_REVERSE_FOREACH (DiscreteConditional::shared_ptr conditional, *this)
conditional->solveInPlace(*result);
return result;
}
/* ************************************************************************* */
DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn) {
DiscreteFactor::sharedValues DiscreteBayesNet::sample() const {
// sample each node in turn in topological sort order (parents first)
DiscreteFactor::sharedValues result(new DiscreteFactor::Values());
BOOST_REVERSE_FOREACH(DiscreteConditional::shared_ptr conditional, bn)
BOOST_REVERSE_FOREACH(DiscreteConditional::shared_ptr conditional, *this)
conditional->sampleInPlace(*result);
return result;
}

View File

@ -20,27 +20,80 @@
#include <vector>
#include <map>
#include <boost/shared_ptr.hpp>
#include <gtsam/inference/BayesNetOrdered.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/discrete/DiscreteConditional.h>
namespace gtsam {
typedef BayesNetOrdered<DiscreteConditional> DiscreteBayesNet;
/** A Bayes net made from linear-Discrete densities */
class GTSAM_EXPORT DiscreteBayesNet: public FactorGraph<DiscreteConditional>
{
public:
typedef FactorGraph<DiscreteConditional> Base;
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr;
typedef boost::shared_ptr<ConditionalType> sharedConditional;
/// @name Standard Constructors
/// @{
/** Construct empty factor graph */
DiscreteBayesNet() {}
/** Construct from iterator over conditionals */
template<typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {}
/// @}
/// @name Testable
/// @{
/** Check equality */
bool equals(const This& bn, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteCondtional */
GTSAM_EXPORT void add(DiscreteBayesNet&, const Signature& s);
GTSAM_EXPORT void add(const Signature& s);
/** Add a DiscreteCondtional in front, when listing parents first*/
GTSAM_EXPORT void add_front(DiscreteBayesNet&, const Signature& s);
// /** Add a DiscreteCondtional in front, when listing parents first*/
// GTSAM_EXPORT void add_front(const Signature& s);
//** evaluate for given Values */
GTSAM_EXPORT double evaluate(const DiscreteBayesNet& bn, const DiscreteConditional::Values & values);
GTSAM_EXPORT double evaluate(const DiscreteConditional::Values & values) const;
/** Optimize function for back-substitution. */
GTSAM_EXPORT DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn);
/**
* Solve the DiscreteBayesNet by back-substitution
*/
DiscreteFactor::sharedValues optimize() const;
/** Do ancestral sampling */
GTSAM_EXPORT DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn);
GTSAM_EXPORT DiscreteFactor::sharedValues sample() const;
///@}
private:
/** Serialization function */
friend class boost::serialization::access;
template<class ARCHIVE>
void serialize(ARCHIVE & ar, const unsigned int version) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
} // namespace

View File

@ -0,0 +1,43 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteBayesTree.cpp
* @brief Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree
* @brief DiscreteBayesTree
* @author Frank Dellaert
* @author Richard Roberts
*/
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
namespace gtsam {
// Instantiate base class
template class BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>;
template class BayesTree<DiscreteBayesTreeClique>;
/* ************************************************************************* */
bool DiscreteBayesTree::equals(const This& other, double tol) const
{
return Base::equals(other, tol);
}
} // \namespace gtsam

View File

@ -0,0 +1,66 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteBayesTree.h
* @brief Discrete Bayes Tree, the result of eliminating a DiscreteJunctionTree
* @brief DiscreteBayesTree
* @author Frank Dellaert
* @author Richard Roberts
*/
#pragma once
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesTree.h>
#include <gtsam/inference/BayesTreeCliqueBase.h>
namespace gtsam {
// Forward declarations
class DiscreteConditional;
class VectorValues;
/* ************************************************************************* */
/** A clique in a DiscreteBayesTree */
class GTSAM_EXPORT DiscreteBayesTreeClique :
public BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph>
{
public:
typedef DiscreteBayesTreeClique This;
typedef BayesTreeCliqueBase<DiscreteBayesTreeClique, DiscreteFactorGraph> Base;
typedef boost::shared_ptr<This> shared_ptr;
typedef boost::weak_ptr<This> weak_ptr;
DiscreteBayesTreeClique() {}
DiscreteBayesTreeClique(const boost::shared_ptr<DiscreteConditional>& conditional) : Base(conditional) {}
};
/* ************************************************************************* */
/** A Bayes tree representing a Discrete density */
class GTSAM_EXPORT DiscreteBayesTree :
public BayesTree<DiscreteBayesTreeClique>
{
private:
typedef BayesTree<DiscreteBayesTreeClique> Base;
public:
typedef DiscreteBayesTree This;
typedef boost::shared_ptr<This> shared_ptr;
/** Default constructor, creates an empty Bayes tree */
DiscreteBayesTree() {}
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;
};
}

View File

@ -18,6 +18,7 @@
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
@ -34,61 +35,72 @@ using namespace std;
namespace gtsam {
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f) :
IndexConditionalOrdered(f.keys(), nrFrontals), Potentials(
f / (*f.sum(nrFrontals))) {
}
// Instantiate base class
template class Conditional<DecisionTreeFactor, DiscreteConditional> ;
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal) :
IndexConditionalOrdered(joint.keys(), joint.size() - marginal.size()), Potentials(
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) {
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f) :
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
}
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal, const boost::optional<Ordering>& orderedKeys) :
BaseFactor(
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
joint.size()-marginal.size()) {
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
cout << (firstFrontalKey()) << endl; //TODO Print all keys
if (orderedKeys) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys->begin(), orderedKeys->end());
}
}
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const Signature& signature) :
IndexConditionalOrdered(signature.indices(), 1), Potentials(
signature.discreteKeysParentsFirst(), signature.cpt()) {
}
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const Signature& signature) :
BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional(
1) {
}
/* ******************************************************************************** */
void DiscreteConditional::print(const std::string& s, const IndexFormatter& formatter) const {
/* ******************************************************************************** */
void DiscreteConditional::print(const std::string& s,
const IndexFormatter& formatter) const {
std::cout << s << std::endl;
IndexConditionalOrdered::print(s, formatter);
Potentials::print(s);
}
}
/* ******************************************************************************** */
bool DiscreteConditional::equals(const DiscreteConditional& other, double tol) const {
return IndexConditionalOrdered::equals(other, tol)
&& Potentials::equals(other, tol);
}
/* ******************************************************************************** */
bool DiscreteConditional::equals(const DiscreteConditional& other,
double tol) const {
return Potentials::equals(other, tol);
}
/* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(
const Values& parentsValues) const {
/* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(const Values& parentsValues) const {
ADT pFS(*this);
Index j; size_t value;
BOOST_FOREACH(Index key, parents())
try {
Index j = (key);
size_t value = parentsValues.at(j);
j = (key);
value = parentsValues.at(j);
pFS = pFS.choose(j, value);
} catch (exception&) {
throw runtime_error(
"DiscreteConditional::choose: parent value missing");
cout << "Key: " << j << " Value: " << value << endl;
pFS.print("pFS: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
return pFS;
}
}
/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(Values& values) const {
/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(Values& values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
print("Me: ");
values.print("values:");
cout << "nrFrontals/nrParents: " << nrFrontals() << "/" << nrParents() << endl;
cout << "first frontal: " << firstFrontalKey() << endl;
ADT pFS = choose(values); // P(F|S=parentsValues)
// Initialize
@ -97,7 +109,7 @@ namespace gtsam {
DiscreteKeys keys;
BOOST_FOREACH(Index idx, frontals()) {
DiscreteKey dk(idx,cardinality(idx));
DiscreteKey dk(idx, cardinality(idx));
keys & dk;
}
// Get all Possible Configurations
@ -117,18 +129,18 @@ namespace gtsam {
BOOST_FOREACH(Index j, frontals()) {
values[j] = mpe[j];
}
}
}
/* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(Values& values) const {
/* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(Values& values) const {
assert(nrFrontals() == 1);
Index j = (firstFrontalKey());
size_t sampled = sample(values); // Sample variable
values[j] = sampled; // store result in partial solution
}
}
/* ******************************************************************************** */
size_t DiscreteConditional::solve(const Values& parentsValues) const {
/* ******************************************************************************** */
size_t DiscreteConditional::solve(const Values& parentsValues) const {
// TODO: is this really the fastest way? I think it is.
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
@ -150,10 +162,10 @@ namespace gtsam {
}
}
return mpe;
}
}
/* ******************************************************************************** */
size_t DiscreteConditional::sample(const Values& parentsValues) const {
/* ******************************************************************************** */
size_t DiscreteConditional::sample(const Values& parentsValues) const {
using boost::uniform_real;
static boost::mt19937 gen(2); // random number generator
@ -162,7 +174,8 @@ namespace gtsam {
// Get the correct conditional density
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
if (debug) GTSAM_PRINT(pFS);
if (debug)
GTSAM_PRINT(pFS);
// get cumulative distribution function (cdf)
// TODO, only works for one key now, seems horribly slow this way
@ -176,9 +189,11 @@ namespace gtsam {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
sum += pValueS; // accumulate
if (debug) cout << sum << " ";
if (debug)
cout << sum << " ";
if (pValueS == 1) {
if (debug) cout << "--> " << value << endl;
if (debug)
cout << "--> " << value << endl;
return value; // shortcut exit
}
cdf[value] = sum;
@ -188,20 +203,20 @@ namespace gtsam {
uniform_real<> dist(0, cdf.back());
boost::variate_generator<boost::mt19937&, uniform_real<> > die(gen, dist);
size_t sampled = lower_bound(cdf.begin(), cdf.end(), die()) - cdf.begin();
if (debug) cout << "-> " << sampled << endl;
if (debug)
cout << "-> " << sampled << endl;
return sampled;
return 0;
}
/* ******************************************************************************** */
void DiscreteConditional::permuteWithInverse(const Permutation& inversePermutation){
IndexConditionalOrdered::permuteWithInverse(inversePermutation);
Potentials::permuteWithInverse(inversePermutation);
}
}
/* ******************************************************************************** */
//void DiscreteConditional::permuteWithInverse(
// const Permutation& inversePermutation) {
// IndexConditionalOrdered::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
//}
/* ******************************************************************************** */
} // namespace
}// namespace

View File

@ -20,25 +20,28 @@
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/IndexConditionalOrdered.h>
#include <gtsam/inference/Conditional.h>
#include <boost/shared_ptr.hpp>
#include <boost/make_shared.hpp>
namespace gtsam {
/**
/**
* Discrete Conditional Density
* Derives from DecisionTreeFactor
*/
class GTSAM_EXPORT DiscreteConditional: public IndexConditionalOrdered, public Potentials {
class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
public Conditional<DecisionTreeFactor, DiscreteConditional> {
public:
public:
// typedefs needed to play nice with gtsam
typedef DiscreteFactor FactorType;
typedef boost::shared_ptr<DiscreteConditional> shared_ptr;
typedef IndexConditionalOrdered Base;
typedef DiscreteConditional This; ///< Typedef to this class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
typedef Conditional<BaseFactor, This> BaseConditional; ///< Typedef to our conditional base class
/** A map from keys to values */
/** A map from keys to values..
* TODO: Again, do we need this??? */
typedef Assignment<Index> Values;
typedef boost::shared_ptr<Values> sharedValues;
@ -57,7 +60,7 @@ namespace gtsam {
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
const DecisionTreeFactor& marginal, const boost::optional<Ordering>& orderedKeys = boost::none);
/**
* Combine several conditional into a single one.
@ -67,7 +70,8 @@ namespace gtsam {
* @param lastConditional Iterator to after the last conditional to combine, must dereference to a shared_ptr<DiscreteConditional>.
* */
template<typename ITERATOR>
static shared_ptr Combine(ITERATOR firstConditional, ITERATOR lastConditional);
static shared_ptr Combine(ITERATOR firstConditional,
ITERATOR lastConditional);
/// @}
/// @name Testable
@ -121,33 +125,29 @@ namespace gtsam {
/// sample in place, stores result in partial solution
void sampleInPlace(Values& parentsValues) const;
/**
* Permutes both IndexConditional and Potentials.
*/
void permuteWithInverse(const Permutation& inversePermutation);
/// @}
};
};
// DiscreteConditional
/* ************************************************************************* */
template<typename ITERATOR>
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
/* ************************************************************************* */
template<typename ITERATOR>
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
ITERATOR firstConditional, ITERATOR lastConditional) {
// TODO: check for being a clique
// multiply all the potentials of the given conditionals
size_t nrFrontals = 0;
DecisionTreeFactor product;
for(ITERATOR it = firstConditional; it != lastConditional; ++it, ++nrFrontals) {
for (ITERATOR it = firstConditional; it != lastConditional;
++it, ++nrFrontals) {
DiscreteConditional::shared_ptr c = *it;
DecisionTreeFactor::shared_ptr factor = c->toFactor();
product = (*factor) * product;
}
// and then create a new multi-frontal conditional
return boost::make_shared<DiscreteConditional>(nrFrontals,product);
}
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
}
}// gtsam
} // gtsam

View File

@ -0,0 +1,44 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteEliminationTree.cpp
* @date Mar 29, 2013
* @author Frank Dellaert
* @author Richard Roberts
*/
#include <gtsam/inference/EliminationTree-inst.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
namespace gtsam {
// Instantiate base class
template class EliminationTree<DiscreteBayesNet, DiscreteFactorGraph>;
/* ************************************************************************* */
DiscreteEliminationTree::DiscreteEliminationTree(
const DiscreteFactorGraph& factorGraph, const VariableIndex& structure,
const Ordering& order) :
Base(factorGraph, structure, order) {}
/* ************************************************************************* */
DiscreteEliminationTree::DiscreteEliminationTree(
const DiscreteFactorGraph& factorGraph, const Ordering& order) :
Base(factorGraph, order) {}
/* ************************************************************************* */
bool DiscreteEliminationTree::equals(const This& other, double tol) const
{
return Base::equals(other, tol);
}
}

View File

@ -0,0 +1,63 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteEliminationTree.h
* @date Mar 29, 2013
* @author Frank Dellaert
* @author Richard Roberts
*/
#pragma once
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/EliminationTree.h>
namespace gtsam {
class GTSAM_EXPORT DiscreteEliminationTree :
public EliminationTree<DiscreteBayesNet, DiscreteFactorGraph>
{
public:
typedef EliminationTree<DiscreteBayesNet, DiscreteFactorGraph> Base; ///< Base class
typedef DiscreteEliminationTree This; ///< This class
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
/**
* Build the elimination tree of a factor graph using pre-computed column structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is not
* precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
DiscreteEliminationTree(const DiscreteFactorGraph& factorGraph,
const VariableIndex& structure, const Ordering& order);
/** Build the elimination tree of a factor graph. Note that this has to compute the column
* structure as a VariableIndex, so if you already have this precomputed, use the other
* constructor instead.
* @param factorGraph The factor graph for which to build the elimination tree
*/
DiscreteEliminationTree(const DiscreteFactorGraph& factorGraph,
const Ordering& order);
/** Test whether the tree is equal to another */
bool equals(const This& other, double tol = 1e-9) const;
private:
friend class ::EliminationTreeTester;
};
}

View File

@ -24,9 +24,5 @@ using namespace std;
namespace gtsam {
/* ******************************************************************************** */
DiscreteFactor::DiscreteFactor() {
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -19,75 +19,72 @@
#pragma once
#include <gtsam/discrete/Assignment.h>
#include <gtsam/inference/IndexFactorOrdered.h>
#include <gtsam/inference/Factor.h>
namespace gtsam {
class DecisionTreeFactor;
class DiscreteConditional;
class DecisionTreeFactor;
class DiscreteConditional;
/**
/**
* Base class for discrete probabilistic factors
* The most general one is the derived DecisionTreeFactor
*/
class GTSAM_EXPORT DiscreteFactor: public IndexFactorOrdered {
class GTSAM_EXPORT DiscreteFactor: public Factor {
public:
public:
// typedefs needed to play nice with gtsam
typedef DiscreteFactor This;
typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<DiscreteFactor> shared_ptr;
typedef DiscreteFactor This; ///< This class
typedef boost::shared_ptr<DiscreteFactor> shared_ptr; ///< shared_ptr to this class
typedef Factor Base; ///< Our base class
/** A map from keys to values */
/** A map from keys to values
* TODO: Do we need this? Should we just use gtsam::Values?
* We just need another special DiscreteValue to represent labels,
* However, all other Lie's operators are undefined in this class.
* The good thing is we can have a Hybrid graph of discrete/continuous variables
* together..
* Another good thing is we don't need to have the special DiscreteKey which stores
* cardinality of a Discrete variable. It should be handled naturally in
* the new class DiscreteValue, as the varible's type (domain)
*/
typedef Assignment<Index> Values;
typedef boost::shared_ptr<Values> sharedValues;
protected:
/// Construct n-way factor
DiscreteFactor(const std::vector<Index>& js) :
IndexFactorOrdered(js) {
}
/// Construct unary factor
DiscreteFactor(Index j) :
IndexFactorOrdered(j) {
}
/// Construct binary factor
DiscreteFactor(Index j1, Index j2) :
IndexFactorOrdered(j1, j2) {
}
/// construct from container
template<class KeyIterator>
DiscreteFactor(KeyIterator beginKey, KeyIterator endKey) :
IndexFactorOrdered(beginKey, endKey) {
}
public:
public:
/// @name Standard Constructors
/// @{
/// Default constructor for I/O
DiscreteFactor();
/** Default constructor creates empty factor */
DiscreteFactor() {}
/** Construct from container of keys. This constructor is used internally from derived factor
* constructors, either from a container of keys or from a boost::assign::list_of. */
template<typename CONTAINER>
DiscreteFactor(const CONTAINER& keys) : Base(keys) {}
/// Virtual destructor
virtual ~DiscreteFactor() {}
virtual ~DiscreteFactor() {
}
/// @}
/// @name Testable
/// @{
// equals
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0;
// print
virtual void print(const std::string& s = "DiscreteFactor\n",
const IndexFormatter& formatter
=DefaultIndexFormatter) const {
IndexFactorOrdered::print(s,formatter);
const IndexFormatter& formatter = DefaultIndexFormatter) const {
Factor::print(s, formatter);
}
/** Test whether the factor is empty */
virtual bool empty() const = 0;
/// @}
/// @name Standard Interface
/// @{
@ -100,23 +97,8 @@ namespace gtsam {
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
/**
* Permutes the factor, but for efficiency requires the permutation
* to already be inverted.
*/
virtual void permuteWithInverse(const Permutation& inversePermutation){
IndexFactorOrdered::permuteWithInverse(inversePermutation);
}
/**
* Apply a reduction, which is a remapping of variable indices.
*/
virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
IndexFactorOrdered::reduceWithInverse(inverseReduction);
}
/// @}
};
};
// DiscreteFactor
}// namespace gtsam

View File

@ -19,23 +19,23 @@
//#define ENABLE_TIMING
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/EliminationTreeOrdered-inl.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <boost/make_shared.hpp>
namespace gtsam {
// Explicitly instantiate so we don't have to include everywhere
template class FactorGraphOrdered<DiscreteFactor> ;
template class EliminationTreeOrdered<DiscreteFactor> ;
// Instantiate base classes
template class FactorGraph<DiscreteFactor>;
template class EliminateableFactorGraph<DiscreteFactorGraph>;
/* ************************************************************************* */
DiscreteFactorGraph::DiscreteFactorGraph() {
}
/* ************************************************************************* */
DiscreteFactorGraph::DiscreteFactorGraph(
const BayesNetOrdered<DiscreteConditional>& bayesNet) :
FactorGraphOrdered<DiscreteFactor>(bayesNet) {
bool DiscreteFactorGraph::equals(const This& fg, double tol) const
{
return Base::equals(fg, tol);
}
/* ************************************************************************* */
@ -75,27 +75,34 @@ namespace gtsam {
}
}
/* ************************************************************************* */
void DiscreteFactorGraph::permuteWithInverse(
const Permutation& inversePermutation) {
BOOST_FOREACH(const sharedFactor& factor, factors_) {
if(factor)
factor->permuteWithInverse(inversePermutation);
}
}
// /* ************************************************************************* */
// void DiscreteFactorGraph::permuteWithInverse(
// const Permutation& inversePermutation) {
// BOOST_FOREACH(const sharedFactor& factor, factors_) {
// if(factor)
// factor->permuteWithInverse(inversePermutation);
// }
// }
//
// /* ************************************************************************* */
// void DiscreteFactorGraph::reduceWithInverse(
// const internal::Reduction& inverseReduction) {
// BOOST_FOREACH(const sharedFactor& factor, factors_) {
// if(factor)
// factor->reduceWithInverse(inverseReduction);
// }
// }
/* ************************************************************************* */
void DiscreteFactorGraph::reduceWithInverse(
const internal::Reduction& inverseReduction) {
BOOST_FOREACH(const sharedFactor& factor, factors_) {
if(factor)
factor->reduceWithInverse(inverseReduction);
}
DiscreteFactor::sharedValues DiscreteFactorGraph::optimize() const
{
gttic(DiscreteFactorGraph_optimize);
return BaseEliminateable::eliminateSequential()->optimize();
}
/* ************************************************************************* */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const FactorGraphOrdered<DiscreteFactor>& factors, size_t num) {
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
@ -107,18 +114,22 @@ namespace gtsam {
// sum out frontals, this is the factor on the separator
gttic(sum);
DecisionTreeFactor::shared_ptr sum = product.sum(num);
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
// now divide product/sum to get conditional
gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
gttoc(divide);
return std::make_pair(cond, sum);
}
/* ************************************************************************* */
} // namespace

View File

@ -18,33 +18,85 @@
#pragma once
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/inference/FactorGraphOrdered.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp>
namespace gtsam {
class DiscreteFactorGraph: public FactorGraphOrdered<DiscreteFactor> {
// Forward declarations
class DiscreteFactorGraph;
class DiscreteFactor;
class DiscreteConditional;
class DiscreteBayesNet;
class DiscreteEliminationTree;
class DiscreteBayesTree;
class DiscreteJunctionTree;
/** Main elimination function for DiscreteFactorGraph */
std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& keys);
/* ************************************************************************* */
template<> struct EliminationTraits<DiscreteFactorGraph>
{
typedef DiscreteFactor FactorType; ///< Type of factors in factor graph
typedef DiscreteFactorGraph FactorGraphType; ///< Type of the factor graph (e.g. DiscreteFactorGraph)
typedef DiscreteConditional ConditionalType; ///< Type of conditionals from elimination
typedef DiscreteBayesNet BayesNetType; ///< Type of Bayes net from sequential elimination
typedef DiscreteEliminationTree EliminationTreeType; ///< Type of elimination tree
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys); }
};
/* ************************************************************************* */
/**
* A Discrete Factor Graph is a factor graph where all factors are Discrete, i.e.
* Factor == DiscreteFactor
*/
class DiscreteFactorGraph: public FactorGraph<DiscreteFactor>,
public EliminateableFactorGraph<DiscreteFactorGraph> {
public:
typedef DiscreteFactorGraph This; ///< Typedef to this class
typedef FactorGraph<DiscreteFactor> Base; ///< Typedef to base factor graph type
typedef EliminateableFactorGraph<This> BaseEliminateable; ///< Typedef to base elimination class
typedef boost::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
/** A map from keys to values */
typedef std::vector<Index> Indices;
typedef Assignment<Index> Values;
typedef boost::shared_ptr<Values> sharedValues;
/** Construct empty factor graph */
GTSAM_EXPORT DiscreteFactorGraph();
/** Default constructor */
DiscreteFactorGraph() {}
/** Constructor from a factor graph of GaussianFactor or a derived type */
/** Construct from iterator over factors */
template<typename ITERATOR>
DiscreteFactorGraph(ITERATOR firstFactor, ITERATOR lastFactor) : Base(firstFactor, lastFactor) {}
/** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER>
explicit DiscreteFactorGraph(const CONTAINER& factors) : Base(factors) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */
template<class DERIVEDFACTOR>
DiscreteFactorGraph(const FactorGraphOrdered<DERIVEDFACTOR>& fg) {
push_back(fg);
}
DiscreteFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
/** construct from a BayesNet */
GTSAM_EXPORT DiscreteFactorGraph(const BayesNetOrdered<DiscreteConditional>& bayesNet);
/// @name Testable
/// @{
bool equals(const This& fg, double tol = 1e-9) const;
/// @}
template<class SOURCE>
void add(const DiscreteKey& j, SOURCE table) {
@ -80,17 +132,19 @@ public:
GTSAM_EXPORT void print(const std::string& s = "DiscreteFactorGraph",
const IndexFormatter& formatter =DefaultIndexFormatter) const;
/** Permute the variables in the factors */
GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
/** Solve the factor graph by performing variable elimination in COLAMD order using
* the dense elimination function specified in \c function,
* followed by back-substitution resulting from elimination. Is equivalent
* to calling graph.eliminateSequential()->optimize(). */
DiscreteFactor::sharedValues optimize() const;
/** Apply a reduction, which is a remapping of variable indices. */
GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
// /** Permute the variables in the factors */
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
//
// /** Apply a reduction, which is a remapping of variable indices. */
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
};
// DiscreteFactorGraph
/** Main elimination function for DiscreteFactorGraph */
GTSAM_EXPORT std::pair<boost::shared_ptr<DiscreteConditional>, DecisionTreeFactor::shared_ptr>
EliminateDiscrete(const FactorGraphOrdered<DiscreteFactor>& factors,
size_t nrFrontals = 1);
} // namespace gtsam

View File

@ -0,0 +1,34 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteJunctionTree.cpp
* @date Mar 29, 2013
* @author Frank Dellaert
* @author Richard Roberts
*/
#include <gtsam/inference/JunctionTree-inst.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
namespace gtsam {
// Instantiate base classes
template class ClusterTree<DiscreteBayesTree, DiscreteFactorGraph>;
template class JunctionTree<DiscreteBayesTree, DiscreteFactorGraph>;
/* ************************************************************************* */
DiscreteJunctionTree::DiscreteJunctionTree(
const DiscreteEliminationTree& eliminationTree) :
Base(eliminationTree) {}
}

View File

@ -0,0 +1,66 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteJunctionTree.h
* @date Mar 29, 2013
* @author Frank Dellaert
* @author Richard Roberts
*/
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/inference/JunctionTree.h>
namespace gtsam {
// Forward declarations
class DiscreteEliminationTree;
/**
* A ClusterTree, i.e., a set of variable clusters with factors, arranged in a tree, with
* the additional property that it represents the clique tree associated with a Bayes net.
*
* In GTSAM a junction tree is an intermediate data structure in multifrontal
* variable elimination. Each node is a cluster of factors, along with a
* clique of variables that are eliminated all at once. In detail, every node k represents
* a clique (maximal fully connected subset) of an associated chordal graph, such as a
* chordal Bayes net resulting from elimination.
*
* The difference with the BayesTree is that a JunctionTree stores factors, whereas a
* BayesTree stores conditionals, that are the product of eliminating the factors in the
* corresponding JunctionTree cliques.
*
* The tree structure and elimination method are exactly analagous to the EliminationTree,
* except that in the JunctionTree, at each node multiple variables are eliminated at a time.
*
* \addtogroup Multifrontal
* \nosubgrouping
*/
class GTSAM_EXPORT DiscreteJunctionTree :
public JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> {
public:
typedef JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> Base; ///< Base class
typedef DiscreteJunctionTree This; ///< This class
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
/**
* Build the elimination tree of a factor graph using pre-computed column structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is not
* precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
};
}

View File

@ -21,7 +21,8 @@
#pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/GenericMultifrontalSolver.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/base/Vector.h>
namespace gtsam {
@ -32,7 +33,7 @@ namespace gtsam {
protected:
BayesTreeOrdered<DiscreteConditional> bayesTree_;
DiscreteBayesTree::shared_ptr bayesTree_;
public:
@ -40,16 +41,14 @@ namespace gtsam {
* @param graph The factor graph defining the full joint density on all variables.
*/
DiscreteMarginals(const DiscreteFactorGraph& graph) {
typedef JunctionTreeOrdered<DiscreteFactorGraph> DiscreteJT;
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
bayesTree_ = *solver.eliminate(&EliminateDiscrete);
bayesTree_ = graph.eliminateMultifrontal();
}
/** Compute the marginal of a single variable */
DiscreteFactor::shared_ptr operator()(Index variable) const {
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor;
marginalFactor = bayesTree_.marginalFactor(variable, &EliminateDiscrete);
marginalFactor = bayesTree_->marginalFactor(variable, &EliminateDiscrete);
return marginalFactor;
}
@ -60,7 +59,7 @@ namespace gtsam {
Vector marginalProbabilities(const DiscreteKey& key) const {
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor;
marginalFactor = bayesTree_.marginalFactor(key.first, &EliminateDiscrete);
marginalFactor = bayesTree_->marginalFactor(key.first, &EliminateDiscrete);
//Create result
Vector vResult(key.second);

View File

@ -1,75 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteSequentialSolver.cpp
* @date Feb 16, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
//#define ENABLE_TIMING
#include <gtsam/discrete/DiscreteSequentialSolver.h>
#include <gtsam/inference/GenericSequentialSolver-inl.h>
#include <gtsam/base/timing.h>
namespace gtsam {
template class GenericSequentialSolver<DiscreteFactor> ;
/* ************************************************************************* */
DiscreteFactor::sharedValues DiscreteSequentialSolver::optimize() const {
static const bool debug = false;
if (debug) this->factors_->print("DiscreteSequentialSolver, eliminating ");
if (debug) this->eliminationTree_->print(
"DiscreteSequentialSolver, elimination tree ");
// Eliminate using the elimination tree
gttic(eliminate);
DiscreteBayesNet::shared_ptr bayesNet = eliminate();
gttoc(eliminate);
if (debug) bayesNet->print("DiscreteSequentialSolver, Bayes net ");
// Allocate the solution vector if it is not already allocated
// Back-substitute
gttic(optimize);
DiscreteFactor::sharedValues solution = gtsam::optimize(*bayesNet);
gttoc(optimize);
if (debug) solution->print("DiscreteSequentialSolver, solution ");
return solution;
}
/* ************************************************************************* */
Vector DiscreteSequentialSolver::marginalProbabilities(
const DiscreteKey& key) const {
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor;
marginalFactor = Base::marginalFactor(key.first, &EliminateDiscrete);
//Create result
Vector vResult(key.second);
for (size_t state = 0; state < key.second; ++state) {
DiscreteFactor::Values values;
values[key.first] = state;
vResult(state) = (*marginalFactor)(values);
}
return vResult;
}
/* ************************************************************************* */
}

View File

@ -1,113 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteSequentialSolver.h
* @date Feb 16, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/GenericSequentialSolver.h>
#include <boost/shared_ptr.hpp>
namespace gtsam {
// The base class provides all of the needed functionality
class DiscreteSequentialSolver: public GenericSequentialSolver<DiscreteFactor> {
protected:
typedef GenericSequentialSolver<DiscreteFactor> Base;
typedef boost::shared_ptr<const DiscreteSequentialSolver> shared_ptr;
public:
/**
* The problem we are trying to solve (SUM or MPE).
*/
typedef enum {
BEL, // Belief updating (or conditional updating)
MPE, // Most-Probable-Explanation
MAP
// Maximum A Posteriori hypothesis
} ProblemType;
/**
* Construct the solver for a factor graph. This builds the elimination
* tree, which already does some of the work of elimination.
*/
DiscreteSequentialSolver(const FactorGraphOrdered<DiscreteFactor>& factorGraph) :
Base(factorGraph) {
}
/**
* Construct the solver with a shared pointer to a factor graph and to a
* VariableIndex. The solver will store these pointers, so this constructor
* is the fastest.
*/
DiscreteSequentialSolver(
const FactorGraphOrdered<DiscreteFactor>::shared_ptr& factorGraph,
const VariableIndexOrdered::shared_ptr& variableIndex) :
Base(factorGraph, variableIndex) {
}
const EliminationTreeOrdered<DiscreteFactor>& eliminationTree() const {
return *eliminationTree_;
}
/**
* Eliminate the factor graph sequentially. Uses a column elimination tree
* to recursively eliminate.
*/
BayesNetOrdered<DiscreteConditional>::shared_ptr eliminate() const {
return Base::eliminate(&EliminateDiscrete);
}
/**
* Compute the marginal joint over a set of variables, by integrating out
* all of the other variables. This function returns the result as a factor
* graph.
*/
DiscreteFactorGraph::shared_ptr jointFactorGraph(
const std::vector<Index>& js) const {
DiscreteFactorGraph::shared_ptr results(new DiscreteFactorGraph(
*Base::jointFactorGraph(js, &EliminateDiscrete)));
return results;
}
/**
* Compute the marginal density over a variable, by integrating out
* all of the other variables. This function returns the result as a factor.
*/
DiscreteFactor::shared_ptr marginalFactor(Index j) const {
return Base::marginalFactor(j, &EliminateDiscrete);
}
/**
* Compute the marginal density over a variable, by integrating out
* all of the other variables. This function returns the result as a
* Vector of the probability values.
*/
GTSAM_EXPORT Vector marginalProbabilities(const DiscreteKey& key) const;
/**
* Compute the MPE solution of the DiscreteFactorGraph. This
* eliminates to create a BayesNet and then back-substitutes this BayesNet to
* obtain the solution.
*/
GTSAM_EXPORT DiscreteFactor::sharedValues optimize() const;
};
} // gtsam

View File

@ -59,39 +59,39 @@ namespace gtsam {
cout << endl;
ADT::print(" ");
}
/* ************************************************************************* */
template<class P>
void Potentials::remapIndices(const P& remapping) {
// Permute the _cardinalities (TODO: Inefficient Consider Improving)
DiscreteKeys keys;
map<Index, Index> ordering;
// Get the original keys from cardinalities_
BOOST_FOREACH(const DiscreteKey& key, cardinalities_)
keys & key;
// Perform Permutation
BOOST_FOREACH(DiscreteKey& key, keys) {
ordering[key.first] = remapping[key.first];
key.first = ordering[key.first];
}
// Change *this
AlgebraicDecisionTree<Index> permuted((*this), ordering);
*this = permuted;
cardinalities_ = keys.cardinalities();
}
/* ************************************************************************* */
void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
remapIndices(inversePermutation);
}
/* ************************************************************************* */
void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
remapIndices(inverseReduction);
}
//
// /* ************************************************************************* */
// template<class P>
// void Potentials::remapIndices(const P& remapping) {
// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
// DiscreteKeys keys;
// map<Index, Index> ordering;
//
// // Get the original keys from cardinalities_
// BOOST_FOREACH(const DiscreteKey& key, cardinalities_)
// keys & key;
//
// // Perform Permutation
// BOOST_FOREACH(DiscreteKey& key, keys) {
// ordering[key.first] = remapping[key.first];
// key.first = ordering[key.first];
// }
//
// // Change *this
// AlgebraicDecisionTree<Index> permuted((*this), ordering);
// *this = permuted;
// cardinalities_ = keys.cardinalities();
// }
//
// /* ************************************************************************* */
// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
// remapIndices(inversePermutation);
// }
//
// /* ************************************************************************* */
// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
// remapIndices(inverseReduction);
// }
/* ************************************************************************* */

View File

@ -19,7 +19,6 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/PermutationOrdered.h>
#include <boost/shared_ptr.hpp>
#include <set>
@ -48,9 +47,9 @@ namespace gtsam {
// Safe division for probabilities
GTSAM_EXPORT static double safe_div(const double& a, const double& b);
// Apply either a permutation or a reduction
template<class P>
void remapIndices(const P& remapping);
// // Apply either a permutation or a reduction
// template<class P>
// void remapIndices(const P& remapping);
public:
@ -73,19 +72,19 @@ namespace gtsam {
size_t cardinality(Index j) const { return cardinalities_.at(j);}
/**
* @brief Permutes the keys in Potentials
*
* This permutes the Indices and performs necessary re-ordering of ADD.
* This is virtual so that derived types e.g. DecisionTreeFactor can
* re-implement it.
*/
GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
/**
* Apply a reduction, which is a remapping of variable indices.
*/
GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
// /**
// * @brief Permutes the keys in Potentials
// *
// * This permutes the Indices and performs necessary re-ordering of ADD.
// * This is virtual so that derived types e.g. DecisionTreeFactor can
// * re-implement it.
// */
// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
}; // Potentials

View File

@ -12,7 +12,7 @@
/**
* @file Signature.cpp
* @brief signatures for conditional densities
* @author Frank dellaert
* @author Frank Dellaert
* @date Feb 27, 2011
*/

View File

@ -12,7 +12,7 @@
/**
* @file Signature.h
* @brief signatures for conditional densities
* @author Frank dellaert
* @author Frank Dellaert
* @date Feb 27, 2011
*/

View File

@ -17,13 +17,15 @@
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteSequentialSolver.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <CppUnitLite/TestHarness.h>
#include <boost/assign/std/map.hpp>
#include <boost/assign/list_inserter.hpp>
using namespace boost::assign;
#include <iostream>
@ -40,38 +42,40 @@ TEST(DiscreteBayesNet, Asia)
DiscreteKey A(0,2), S(4,2), T(3,2), L(6,2), B(7,2), E(5,2), X(2,2), D(1,2);
// TODO: make a version that doesn't use the parser
add_front(asia, A % "99/1");
add_front(asia, S % "50/50");
asia.add(A % "99/1");
asia.add(S % "50/50");
add_front(asia, T | A = "99/1 95/5");
add_front(asia, L | S = "99/1 90/10");
add_front(asia, B | S = "70/30 40/60");
asia.add(T | A = "99/1 95/5");
asia.add(L | S = "99/1 90/10");
asia.add(B | S = "70/30 40/60");
add_front(asia, (E | T, L) = "F T T T");
asia.add((E | T, L) = "F T T T");
add_front(asia, X | E = "95/5 2/98");
// next lines are same as add_front(asia, (D | E, B) = "9/1 2/8 3/7 1/9");
asia.add(X | E = "95/5 2/98");
// next lines are same as asia.add((D | E, B) = "9/1 2/8 3/7 1/9");
DiscreteConditional::shared_ptr actual =
boost::make_shared<DiscreteConditional>((D | E, B) = "9/1 2/8 3/7 1/9");
asia.push_front(actual);
asia.push_back(actual);
// GTSAM_PRINT(asia);
// Convert to factor graph
DiscreteFactorGraph fg(asia);
// GTSAM_PRINT(fg);
LONGS_EQUAL(3,fg.front()->size());
// GTSAM_PRINT(fg);
LONGS_EQUAL(3,fg.back()->size());
Potentials::ADT expected(B & D & E, "0.9 0.3 0.1 0.7 0.2 0.1 0.8 0.9");
CHECK(assert_equal(expected,(Potentials::ADT)*actual));
// Create solver and eliminate
DiscreteSequentialSolver solver(fg);
DiscreteBayesNet::shared_ptr chordal = solver.eliminate();
// GTSAM_PRINT(*chordal);
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4),Key(5),Key(6),Key(7);
ordering.print("ordering: ");
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// GTSAM_PRINT(*chordal);
DiscreteConditional expected2(B % "11/9");
CHECK(assert_equal(expected2,*chordal->back()));
// solve
DiscreteFactor::sharedValues actualMPE = optimize(*chordal);
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
DiscreteFactor::Values expectedMPE;
insert(expectedMPE)(A.first, 0)(D.first, 0)(X.first, 0)(T.first, 0)(S.first,
0)(E.first, 0)(L.first, 0)(B.first, 0);
@ -83,10 +87,10 @@ TEST(DiscreteBayesNet, Asia)
// fg.product().dot("fg");
// solve again, now with evidence
DiscreteSequentialSolver solver2(fg);
DiscreteBayesNet::shared_ptr chordal2 = solver2.eliminate();
// GTSAM_PRINT(*chordal2);
DiscreteFactor::sharedValues actualMPE2 = optimize(*chordal2);
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential();
GTSAM_PRINT(*chordal2);
DiscreteFactor::sharedValues actualMPE2 = chordal2->optimize();
cout << "optimized!" << endl;
DiscreteFactor::Values expectedMPE2;
insert(expectedMPE2)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(S.first,
1)(E.first, 0)(L.first, 0)(B.first, 1);
@ -97,8 +101,8 @@ TEST(DiscreteBayesNet, Asia)
SETDEBUG("DiscreteConditional::sample", false);
insert(expectedSample)(A.first, 1)(D.first, 1)(X.first, 0)(T.first, 0)(
S.first, 1)(E.first, 0)(L.first, 0)(B.first, 1);
DiscreteFactor::sharedValues actualSample = sample(*chordal2);
EXPECT(assert_equal(expectedSample, *actualSample));
DiscreteFactor::sharedValues actualSample = chordal->sample();
// EXPECT(assert_equal(expectedSample, *actualSample));
}
/* ************************************************************************* */
@ -114,12 +118,12 @@ TEST_UNSAFE(DiscreteBayesNet, Sugar)
// add(bn, D | E = "blah");
// try logic
add(bn, (E | T, L) = "OR");
add(bn, (E | T, L) = "AND");
bn.add((E | T, L) = "OR");
bn.add((E | T, L) = "AND");
// // try multivalued
add(bn, C % "1/1/2");
add(bn, C | S = "1/1/2 5/2/3");
bn.add(C % "1/1/2");
bn.add(C | S = "1/1/2 5/2/3");
}
/* ************************************************************************* */

View File

@ -1,261 +1,261 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* @file testDiscreteBayesTree.cpp
* @date sept 15, 2012
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/BayesTreeOrdered.h>
#include <boost/assign/std/vector.hpp>
using namespace boost::assign;
///* ----------------------------------------------------------------------------
//
// * GTSAM Copyright 2010, Georgia Tech Research Corporation,
// * Atlanta, Georgia 30332-0415
// * All Rights Reserved
// * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
//
// * See LICENSE for the license information
//
// * -------------------------------------------------------------------------- */
//
///*
// * @file testDiscreteBayesTree.cpp
// * @date sept 15, 2012
// * @author Frank Dellaert
// */
//
//#include <gtsam/discrete/DiscreteBayesNet.h>
//#include <gtsam/discrete/DiscreteBayesTree.h>
//#include <gtsam/discrete/DiscreteFactorGraph.h>
//
//#include <boost/assign/std/vector.hpp>
//using namespace boost::assign;
//
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
static bool debug = false;
/**
* Custom clique class to debug shortcuts
*/
class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> {
protected:
public:
typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base;
typedef boost::shared_ptr<Clique> shared_ptr;
// Constructors
Clique() {
}
Clique(const DiscreteConditional::shared_ptr& conditional) :
Base(conditional) {
}
Clique(
const std::pair<DiscreteConditional::shared_ptr,
DiscreteConditional::FactorType::shared_ptr>& result) :
Base(result) {
}
/// print index signature only
void printSignature(const std::string& s = "Clique: ",
const IndexFormatter& indexFormatter = DefaultIndexFormatter) const {
((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter);
}
/// evaluate value of sub-tree
double evaluate(const DiscreteConditional::Values & values) {
double result = (*(this->conditional_))(values);
// evaluate all children and multiply into result
BOOST_FOREACH(boost::shared_ptr<Clique> c, children_)
result *= c->evaluate(values);
return result;
}
};
typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
/* ************************************************************************* */
double evaluate(const DiscreteBayesTree& tree,
const DiscreteConditional::Values & values) {
return tree.root()->evaluate(values);
}
/* ************************************************************************* */
TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
const int nrNodes = 15;
const size_t nrStates = 2;
// define variables
vector<DiscreteKey> key;
for (int i = 0; i < nrNodes; i++) {
DiscreteKey key_i(i, nrStates);
key.push_back(key_i);
}
// create a thin-tree Bayesnet, a la Jean-Guillaume
DiscreteBayesNet bayesNet;
add_front(bayesNet, key[14] % "1/3");
add_front(bayesNet, key[13] | key[14] = "1/3 3/1");
add_front(bayesNet, key[12] | key[14] = "3/1 3/1");
add_front(bayesNet, (key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
add_front(bayesNet, (key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
add_front(bayesNet, (key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
add_front(bayesNet, (key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
add_front(bayesNet, (key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
add_front(bayesNet, (key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
add_front(bayesNet, (key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
add_front(bayesNet, (key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
add_front(bayesNet, (key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
add_front(bayesNet, (key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
add_front(bayesNet, (key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
add_front(bayesNet, (key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
if (debug) {
GTSAM_PRINT(bayesNet);
bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
}
// create a BayesTree out of a Bayes net
DiscreteBayesTree bayesTree(bayesNet);
if (debug) {
GTSAM_PRINT(bayesTree);
bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
}
// Check whether BN and BT give the same answer on all configurations
// Also calculate all some marginals
Vector marginals = zero(15);
double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
joint_4_11 = 0;
vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
& key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
for (size_t i = 0; i < allPosbValues.size(); ++i) {
DiscreteFactor::Values x = allPosbValues[i];
double expected = evaluate(bayesNet, x);
double actual = evaluate(bayesTree, x);
DOUBLES_EQUAL(expected, actual, 1e-9);
// collect marginals
for (size_t i = 0; i < 15; i++)
if (x[i])
marginals[i] += actual;
// calculate shortcut 8 and 0
if (x[12] && x[14])
joint_12_14 += actual;
if (x[9] && x[12] & x[14])
joint_9_12_14 += actual;
if (x[8] && x[12] & x[14])
joint_8_12_14 += actual;
if (x[8] && x[12])
joint_8_12 += actual;
if (x[8] && x[2])
joint82 += actual;
if (x[1] && x[2])
joint12 += actual;
if (x[2] && x[4])
joint24 += actual;
if (x[4] && x[5])
joint45 += actual;
if (x[4] && x[6])
joint46 += actual;
if (x[4] && x[11])
joint_4_11 += actual;
}
DiscreteFactor::Values all1 = allPosbValues.back();
Clique::shared_ptr R = bayesTree.root();
// check separator marginal P(S0)
Clique::shared_ptr c = bayesTree[0];
DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R,
EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
// check separator marginal P(S9), should be P(14)
c = bayesTree[9];
DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R,
EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
// check separator marginal of root, should be empty
c = bayesTree[11];
DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
EliminateDiscrete);
EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
// check shortcut P(S9||R) to root
c = bayesTree[9];
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_LONGS_EQUAL(0, shortcut.size());
// check shortcut P(S8||R) to root
c = bayesTree[8];
shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
1e-9);
// check shortcut P(S2||R) to root
c = bayesTree[2];
shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
1e-9);
// check shortcut P(S0||R) to root
c = bayesTree[0];
shortcut = c->shortcut(R, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
1e-9);
// calculate all shortcuts to root
DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
BOOST_FOREACH(Clique::shared_ptr c, cliques) {
DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
if (debug) {
c->printSignature();
shortcut.print("shortcut:");
}
}
// Check all marginals
DiscreteFactor::shared_ptr marginalFactor;
for (size_t i = 0; i < 15; i++) {
marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
double actual = (*marginalFactor)(all1);
EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
}
DiscreteBayesNet::shared_ptr actualJoint;
// Check joint P(8,2) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(1,2) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(2,4)
actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(4,5) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(4,6) TODO: not disjoint !
// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
// Check joint P(4,11)
actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
}
//
//using namespace std;
//using namespace gtsam;
//
//static bool debug = false;
//
///**
// * Custom clique class to debug shortcuts
// */
////class Clique: public BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> {
////
////protected:
////
////public:
////
//// typedef BayesTreeCliqueBaseOrdered<Clique, DiscreteConditional> Base;
//// typedef boost::shared_ptr<Clique> shared_ptr;
////
//// // Constructors
//// Clique() {
//// }
//// Clique(const DiscreteConditional::shared_ptr& conditional) :
//// Base(conditional) {
//// }
//// Clique(
//// const std::pair<DiscreteConditional::shared_ptr,
//// DiscreteConditional::FactorType::shared_ptr>& result) :
//// Base(result) {
//// }
////
//// /// print index signature only
//// void printSignature(const std::string& s = "Clique: ",
//// const IndexFormatter& indexFormatter = DefaultIndexFormatter) const {
//// ((IndexConditionalOrdered::shared_ptr) conditional_)->print(s, indexFormatter);
//// }
////
//// /// evaluate value of sub-tree
//// double evaluate(const DiscreteConditional::Values & values) {
//// double result = (*(this->conditional_))(values);
//// // evaluate all children and multiply into result
//// BOOST_FOREACH(boost::shared_ptr<Clique> c, children_)
//// result *= c->evaluate(values);
//// return result;
//// }
////
////};
//
////typedef BayesTreeOrdered<DiscreteConditional, Clique> DiscreteBayesTree;
////
/////* ************************************************************************* */
////double evaluate(const DiscreteBayesTree& tree,
//// const DiscreteConditional::Values & values) {
//// return tree.root()->evaluate(values);
////}
//
///* ************************************************************************* */
//
//TEST_UNSAFE( DiscreteBayesTree, thinTree ) {
//
// const int nrNodes = 15;
// const size_t nrStates = 2;
//
// // define variables
// vector<DiscreteKey> key;
// for (int i = 0; i < nrNodes; i++) {
// DiscreteKey key_i(i, nrStates);
// key.push_back(key_i);
// }
//
// // create a thin-tree Bayesnet, a la Jean-Guillaume
// DiscreteBayesNet bayesNet;
// bayesNet.add(key[14] % "1/3");
//
// bayesNet.add(key[13] | key[14] = "1/3 3/1");
// bayesNet.add(key[12] | key[14] = "3/1 3/1");
//
// bayesNet.add((key[11] | key[13], key[14]) = "1/4 2/3 3/2 4/1");
// bayesNet.add((key[10] | key[13], key[14]) = "1/4 3/2 2/3 4/1");
// bayesNet.add((key[9] | key[12], key[14]) = "4/1 2/3 F 1/4");
// bayesNet.add((key[8] | key[12], key[14]) = "T 1/4 3/2 4/1");
//
// bayesNet.add((key[7] | key[11], key[13]) = "1/4 2/3 3/2 4/1");
// bayesNet.add((key[6] | key[11], key[13]) = "1/4 3/2 2/3 4/1");
// bayesNet.add((key[5] | key[10], key[13]) = "4/1 2/3 3/2 1/4");
// bayesNet.add((key[4] | key[10], key[13]) = "2/3 1/4 3/2 4/1");
//
// bayesNet.add((key[3] | key[9], key[12]) = "1/4 2/3 3/2 4/1");
// bayesNet.add((key[2] | key[9], key[12]) = "1/4 8/2 2/3 4/1");
// bayesNet.add((key[1] | key[8], key[12]) = "4/1 2/3 3/2 1/4");
// bayesNet.add((key[0] | key[8], key[12]) = "2/3 1/4 3/2 4/1");
//
//// if (debug) {
//// GTSAM_PRINT(bayesNet);
//// bayesNet.saveGraph("/tmp/discreteBayesNet.dot");
//// }
//
// // create a BayesTree out of a Bayes net
// DiscreteBayesTree bayesTree(bayesNet);
// if (debug) {
// GTSAM_PRINT(bayesTree);
// bayesTree.saveGraph("/tmp/discreteBayesTree.dot");
// }
//
// // Check whether BN and BT give the same answer on all configurations
// // Also calculate all some marginals
// Vector marginals = zero(15);
// double joint_12_14 = 0, joint_9_12_14 = 0, joint_8_12_14 = 0, joint_8_12 = 0,
// joint82 = 0, joint12 = 0, joint24 = 0, joint45 = 0, joint46 = 0,
// joint_4_11 = 0;
// vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
// key[0] & key[1] & key[2] & key[3] & key[4] & key[5] & key[6] & key[7]
// & key[8] & key[9] & key[10] & key[11] & key[12] & key[13] & key[14]);
// for (size_t i = 0; i < allPosbValues.size(); ++i) {
// DiscreteFactor::Values x = allPosbValues[i];
// double expected = evaluate(bayesNet, x);
// double actual = evaluate(bayesTree, x);
// DOUBLES_EQUAL(expected, actual, 1e-9);
// // collect marginals
// for (size_t i = 0; i < 15; i++)
// if (x[i])
// marginals[i] += actual;
// // calculate shortcut 8 and 0
// if (x[12] && x[14])
// joint_12_14 += actual;
// if (x[9] && x[12] & x[14])
// joint_9_12_14 += actual;
// if (x[8] && x[12] & x[14])
// joint_8_12_14 += actual;
// if (x[8] && x[12])
// joint_8_12 += actual;
// if (x[8] && x[2])
// joint82 += actual;
// if (x[1] && x[2])
// joint12 += actual;
// if (x[2] && x[4])
// joint24 += actual;
// if (x[4] && x[5])
// joint45 += actual;
// if (x[4] && x[6])
// joint46 += actual;
// if (x[4] && x[11])
// joint_4_11 += actual;
// }
// DiscreteFactor::Values all1 = allPosbValues.back();
//
// Clique::shared_ptr R = bayesTree.root();
//
// // check separator marginal P(S0)
// Clique::shared_ptr c = bayesTree[0];
// DiscreteFactorGraph separatorMarginal0 = c->separatorMarginal(R,
// EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint_8_12, separatorMarginal0(all1), 1e-9);
//
// // check separator marginal P(S9), should be P(14)
// c = bayesTree[9];
// DiscreteFactorGraph separatorMarginal9 = c->separatorMarginal(R,
// EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(marginals[14], separatorMarginal9(all1), 1e-9);
//
// // check separator marginal of root, should be empty
// c = bayesTree[11];
// DiscreteFactorGraph separatorMarginal11 = c->separatorMarginal(R,
// EliminateDiscrete);
// EXPECT_LONGS_EQUAL(0, separatorMarginal11.size());
//
// // check shortcut P(S9||R) to root
// c = bayesTree[9];
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
// EXPECT_LONGS_EQUAL(0, shortcut.size());
//
// // check shortcut P(S8||R) to root
// c = bayesTree[8];
// shortcut = c->shortcut(R, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint_12_14/marginals[14], evaluate(shortcut,all1),
// 1e-9);
//
// // check shortcut P(S2||R) to root
// c = bayesTree[2];
// shortcut = c->shortcut(R, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint_9_12_14/marginals[14], evaluate(shortcut,all1),
// 1e-9);
//
// // check shortcut P(S0||R) to root
// c = bayesTree[0];
// shortcut = c->shortcut(R, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint_8_12_14/marginals[14], evaluate(shortcut,all1),
// 1e-9);
//
// // calculate all shortcuts to root
// DiscreteBayesTree::Nodes cliques = bayesTree.nodes();
// BOOST_FOREACH(Clique::shared_ptr c, cliques) {
// DiscreteBayesNet shortcut = c->shortcut(R, EliminateDiscrete);
// if (debug) {
// c->printSignature();
// shortcut.print("shortcut:");
// }
// }
//
// // Check all marginals
// DiscreteFactor::shared_ptr marginalFactor;
// for (size_t i = 0; i < 15; i++) {
// marginalFactor = bayesTree.marginalFactor(i, EliminateDiscrete);
// double actual = (*marginalFactor)(all1);
// EXPECT_DOUBLES_EQUAL(marginals[i], actual, 1e-9);
// }
//
// DiscreteBayesNet::shared_ptr actualJoint;
//
// // Check joint P(8,2) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(8, 2, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint82, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(1,2) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(1, 2, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint12, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(2,4)
// actualJoint = bayesTree.jointBayesNet(2, 4, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint24, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(4,5) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(4, 5, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(4,6) TODO: not disjoint !
//// actualJoint = bayesTree.jointBayesNet(4, 6, EliminateDiscrete);
//// EXPECT_DOUBLES_EQUAL(joint46, evaluate(*actualJoint,all1), 1e-9);
//
// // Check joint P(4,11)
// actualJoint = bayesTree.jointBayesNet(4, 11, EliminateDiscrete);
// EXPECT_DOUBLES_EQUAL(joint_4_11, evaluate(*actualJoint,all1), 1e-9);
//
//}
/* ************************************************************************* */
int main() {

View File

@ -17,8 +17,8 @@
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteSequentialSolver.h>
#include <gtsam/inference/GenericMultifrontalSolver.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <CppUnitLite/TestHarness.h>
@ -58,9 +58,9 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
// result.first->print("BayesNet: ");
// result.second->print("New factor: ");
//
DiscreteSequentialSolver solver(graph);
// solver.print("solver:");
EliminationTreeOrdered<DiscreteFactor> eliminationTree = solver.eliminationTree();
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3);
DiscreteEliminationTree eliminationTree(graph, ordering);
// eliminationTree.print("Elimination tree: ");
eliminationTree.eliminate(EliminateDiscrete);
// solver.optimize();
@ -127,10 +127,11 @@ TEST( DiscreteFactorGraph, test)
// Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys;
frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional;
DecisionTreeFactor::shared_ptr newFactor;
boost::tie(conditional, newFactor) =//
EliminateDiscrete(graph, 1);
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
// Check Bayes net
CHECK(conditional);
@ -139,34 +140,35 @@ TEST( DiscreteFactorGraph, test)
// cout << signature << endl;
DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional));
add(expected, signature);
expected.add(signature);
// Check Factor
CHECK(newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor));
// Create solver
DiscreteSequentialSolver solver(graph);
// add conditionals to complete expected Bayes net
add(expected, B | A = "5/3 3/5");
add(expected, A % "1/1");
expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1");
// GTSAM_PRINT(expected);
// Test elimination tree
const EliminationTreeOrdered<DiscreteFactor>& etree = solver.eliminationTree();
DiscreteBayesNet::shared_ptr actual = etree.eliminate(&EliminateDiscrete);
Ordering ordering;
ordering += Key(0), Key(1), Key(2);
DiscreteEliminationTree etree(graph, ordering);
DiscreteBayesNet::shared_ptr actual;
DiscreteFactorGraph::shared_ptr remainingGraph;
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual));
// Test solver
DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
EXPECT(assert_equal(expected, *actual2));
// // Test solver
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
// EXPECT(assert_equal(expected, *actual2));
// Test optimization
DiscreteFactor::Values expectedValues;
insert(expectedValues)(0, 0)(1, 0)(2, 0);
DiscreteFactor::sharedValues actualValues = solver.optimize();
DiscreteFactor::sharedValues actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, *actualValues));
}
@ -183,8 +185,7 @@ TEST( DiscreteFactorGraph, testMPE)
// graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteFactor::sharedValues actualMPE =
DiscreteSequentialSolver(graph).optimize();
DiscreteFactor::sharedValues actualMPE = graph.optimize();
DiscreteFactor::Values expectedMPE;
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
@ -213,18 +214,23 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
// Use the solver machinery.
DiscreteBayesNet::shared_ptr chordal =
DiscreteSequentialSolver(graph).eliminate();
DiscreteFactor::sharedValues actualMPE = optimize(*chordal);
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
DiscreteFactor::sharedValues actualMPE = chordal->optimize();
EXPECT(assert_equal(expectedMPE, *actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now
typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
// bayesTree->print("Bayes Tree");
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
//// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2,bayesTree->size());
#ifdef OLD

View File

@ -18,7 +18,6 @@
*/
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/discrete/DiscreteSequentialSolver.h>
#include <boost/assign/std/vector.hpp>
using namespace boost::assign;
@ -119,11 +118,9 @@ TEST_UNSAFE( DiscreteMarginals, truss ) {
graph.add(key[0] & key[2] & key[4],"2 2 2 2 1 1 1 1");
graph.add(key[1] & key[3] & key[4],"1 1 1 1 2 2 2 2");
graph.add(key[2] & key[3] & key[4],"1 1 1 1 1 1 1 1");
typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal();
// bayesTree->print("Bayes Tree");
typedef BayesTreeClique<DiscreteConditional> Clique;
typedef DiscreteBayesTreeClique Clique;
Clique expected0(boost::make_shared<DiscreteConditional>((key[0] | key[2], key[4]) = "2/1 2/1 2/1 2/1"));
Clique::shared_ptr actual0 = (*bayesTree)[0];
@ -180,16 +177,16 @@ TEST_UNSAFE( DiscreteMarginals, truss2 ) {
}
// Check all marginals given by a sequential solver and Marginals
DiscreteSequentialSolver solver(graph);
// DiscreteSequentialSolver solver(graph);
DiscreteMarginals marginals(graph);
for (size_t j=0;j<5;j++) {
double sum = T[j]+F[j];
T[j]/=sum;
F[j]/=sum;
// solver
Vector actualV = solver.marginalProbabilities(key[j]);
EXPECT(assert_equal(Vector_(2,F[j],T[j]), actualV));
// // solver
// Vector actualV = solver.marginalProbabilities(key[j]);
// EXPECT(assert_equal(Vector_(2,F[j],T[j]), actualV));
// Marginals
vector<double> table;