Merging all Discrete changes

release/4.3a0
Abhijit Kundu 2012-06-06 03:10:13 +00:00
parent e6a0663540
commit c31ef559f8
10 changed files with 203 additions and 80 deletions

View File

@ -127,6 +127,11 @@ namespace gtsam {
*/
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
void permuteWithInverse(const Permutation& inversePermutation){
DiscreteFactor::permuteWithInverse(inversePermutation);
Potentials::permute(inversePermutation);
}
/// @}
};
// DecisionTreeFactor

View File

@ -46,9 +46,9 @@ namespace gtsam {
const DecisionTreeFactor& marginal) :
IndexConditional(joint.keys(), joint.size() - marginal.size()), Potentials(
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) {
assert(nrFrontals() == 1);
// assert(nrFrontals() == 1);
if (ISDEBUG("DiscreteConditional::DiscreteConditional")) cout
<< (firstFrontalKey()) << endl;
<< (firstFrontalKey()) << endl; //TODO Print all keys
}
/* ******************************************************************************** */
@ -75,10 +75,47 @@ namespace gtsam {
/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(Values& values) const {
assert(nrFrontals() == 1);
Index j = (firstFrontalKey());
size_t mpe = solve(values); // Solve for variable
values[j] = mpe; // store result in partial solution
// assert(nrFrontals() == 1);
// Index j = (firstFrontalKey());
// size_t mpe = solve(values); // Solve for variable
// values[j] = mpe; // store result in partial solution
// TODO: is this really the fastest way? I think it is.
ADT pFS = choose(values); // P(F|S=parentsValues)
// Initialize
Values mpe;
Values frontalVals;
BOOST_FOREACH(Index j, frontals()) {
frontalVals[j] = 0;
}
double maxP = 0;
while (1) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
size_t j = 0;
for (j = 0; j < nrFrontals(); j++) {
Index idx = frontals()[j];
frontalVals[idx]++;
if (frontalVals[idx] < cardinality(idx))
break;
//Wrap condition
frontalVals[idx] = 0;
}
if (j == nrFrontals())
break;
}
//set values (inPlace) to mpe
BOOST_FOREACH(Index j, frontals()) {
values[j] = mpe[j];
}
}
/* ******************************************************************************** */

View File

@ -98,6 +98,14 @@ namespace gtsam {
virtual operator DecisionTreeFactor() const = 0;
/**
* Permutes the factor, but for efficiency requires the permutation
* to already be inverted.
*/
virtual void permuteWithInverse(const Permutation& inversePermutation){
IndexFactor::permuteWithInverse(inversePermutation);
}
/// @}
};
// DiscreteFactor

View File

@ -24,54 +24,56 @@
namespace gtsam {
// Explicitly instantiate so we don't have to include everywhere
template class FactorGraph<DiscreteFactor> ;
template class EliminationTree<DiscreteFactor> ;
// Explicitly instantiate so we don't have to include everywhere
template class FactorGraph<DiscreteFactor> ;
template class EliminationTree<DiscreteFactor> ;
/* ************************************************************************* */
DiscreteFactorGraph::DiscreteFactorGraph() {
}
/* ************************************************************************* */
DiscreteFactorGraph::DiscreteFactorGraph() {
}
/* ************************************************************************* */
DiscreteFactorGraph::DiscreteFactorGraph(
/* ************************************************************************* */
DiscreteFactorGraph::DiscreteFactorGraph(
const BayesNet<DiscreteConditional>& bayesNet) :
FactorGraph<DiscreteFactor>(bayesNet) {
}
}
/* ************************************************************************* */
FastSet<Index> DiscreteFactorGraph::keys() const {
/* ************************************************************************* */
FastSet<Index> DiscreteFactorGraph::keys() const {
FastSet<Index> keys;
BOOST_FOREACH(const sharedFactor& factor, *this)
if (factor) keys.insert(factor->begin(), factor->end());
return keys;
}
}
/* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const {
/* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
BOOST_FOREACH(const sharedFactor& factor, *this)
if (factor) result = (*factor) * result;
return result;
}
}
/* ************************************************************************* */
double DiscreteFactorGraph::operator()(
/* ************************************************************************* */
double DiscreteFactorGraph::operator()(
const DiscreteFactor::Values &values) const {
double product = 1.0;
BOOST_FOREACH( const sharedFactor& factor, factors_ )
product *= (*factor)(values);
return product;
}
}
/* ************************************************************************* */
pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const FactorGraph<DiscreteFactor>& factors, size_t num) {
/* ************************************************************************* */
pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const FactorGraph<DiscreteFactor>& factors, size_t num) {
// PRODUCT: multiply all factors
tic(1, "product");
DecisionTreeFactor product;
BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors)
BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors){
product = (*factor) * product;
}
toc(1, "product");
// sum out frontals, this is the factor on the separator
@ -86,7 +88,7 @@ EliminateDiscrete(const FactorGraph<DiscreteFactor>& factors, size_t num) {
tictoc_finishedIteration();
return make_pair(cond, sum);
}
}
/* ************************************************************************* */
} // namespace

View File

@ -86,7 +86,6 @@ public:
if (factors_[i] != NULL) factors_[i]->print(ss.str());
}
}
};
// DiscreteFactorGraph

View File

@ -21,6 +21,8 @@
#pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/GenericMultifrontalSolver.h>
#include <gtsam/linear/GaussianBayesTree.h>
namespace gtsam {
@ -39,22 +41,38 @@ namespace gtsam {
* @param graph The factor graph defining the full joint density on all variables.
*/
DiscreteMarginals(const DiscreteFactorGraph& graph) {
}
/** print */
void print(const std::string& str = "DiscreteMarginals: ") const {
typedef JunctionTree<DiscreteFactorGraph> DiscreteJT;
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
bayesTree_ = *solver.eliminate(&EliminateDiscrete);
//bayesTree_.print();
cout << "Discrete Marginal Constructor Finished" << endl;
}
/** Compute the marginal of a single variable */
DiscreteFactor::shared_ptr operator()(Index variable) const {
DiscreteFactor::shared_ptr p;
return p;
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor;
marginalFactor = bayesTree_.marginalFactor(variable, &EliminateDiscrete);
return marginalFactor;
}
/** Compute the marginal of a single variable */
Vector marginalProbabilities(Index variable) const {
Vector v;
return v;
/** Compute the marginal of a single variable
* @param KEY DiscreteKey of the Variable
* @return Vector of marginal probabilities
*/
Vector marginalProbabilities(const DiscreteKey& key) const {
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor;
marginalFactor = bayesTree_.marginalFactor(key.first, &EliminateDiscrete);
//Create result
Vector vResult(key.second);
for (size_t state = 0; state < key.second ; ++ state) {
DiscreteFactor::Values values;
values[key.first] = state;
vResult(state) = (*marginalFactor)(values);
}
return vResult;
}
};

View File

@ -52,6 +52,24 @@ namespace gtsam {
return solution;
}
/* ************************************************************************* */
Vector DiscreteSequentialSolver::marginalProbabilities(
const DiscreteKey& key) const {
// Compute marginal
DiscreteFactor::shared_ptr marginalFactor;
marginalFactor = Base::marginalFactor(key.first, &EliminateDiscrete);
//Create result
Vector vResult(key.second);
for (size_t state = 0; state < key.second; ++state) {
DiscreteFactor::Values values;
values[key.first] = state;
vResult(state) = (*marginalFactor)(values);
}
return vResult;
}
/* ************************************************************************* */
}

View File

@ -74,7 +74,6 @@ namespace gtsam {
return Base::eliminate(&EliminateDiscrete);
}
#ifdef BROKEN
/**
* Compute the marginal joint over a set of variables, by integrating out
* all of the other variables. This function returns the result as a factor
@ -94,7 +93,13 @@ namespace gtsam {
DiscreteFactor::shared_ptr marginalFactor(Index j) const {
return Base::marginalFactor(j, &EliminateDiscrete);
}
#endif
/**
* Compute the marginal density over a variable, by integrating out
* all of the other variables. This function returns the result as a
* Vector of the probability values.
*/
Vector marginalProbabilities(const DiscreteKey& key) const;
/**
* Compute the MPE solution of the DiscreteFactorGraph. This

View File

@ -60,5 +60,28 @@ namespace gtsam {
}
/* ************************************************************************* */
void Potentials::permute(const Permutation& permutation) {
// Permute the _cardinalities (TODO: Inefficient Consider Improving)
DiscreteKeys keys;
map<Index, Index> ordering;
// Get the orginal keys from cardinalities_
BOOST_FOREACH(const DiscreteKey& key, cardinalities_)
keys & key;
// Perform Permutation
BOOST_FOREACH(DiscreteKey& key, keys) {
ordering[key.first] = permutation[key.first];
//cout << key.first << " -> " << ordering[key.first] << endl;
key.first = ordering[key.first];
}
// Change *this
AlgebraicDecisionTree<Index> permuted((*this), ordering);
*this = permuted;
cardinalities_ = keys.cardinalities();
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -20,6 +20,7 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/base/types.h>
#include <gtsam/inference/Permutation.h>
#include <boost/shared_ptr.hpp>
#include <set>
@ -68,6 +69,13 @@ namespace gtsam {
size_t cardinality(Index j) const { return cardinalities_.at(j);}
/*
* @brief Permutes the keys in Potentials
*
* This permutes the Indices and performs necessary re-ordering of ADD.
*/
void permute(const Permutation& perm);
}; // Potentials
} // namespace gtsam