Merging all Discrete changes
parent
e6a0663540
commit
c31ef559f8
|
@ -127,6 +127,11 @@ namespace gtsam {
|
|||
*/
|
||||
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
|
||||
|
||||
void permuteWithInverse(const Permutation& inversePermutation){
|
||||
DiscreteFactor::permuteWithInverse(inversePermutation);
|
||||
Potentials::permute(inversePermutation);
|
||||
}
|
||||
|
||||
/// @}
|
||||
};
|
||||
// DecisionTreeFactor
|
||||
|
|
|
@ -46,9 +46,9 @@ namespace gtsam {
|
|||
const DecisionTreeFactor& marginal) :
|
||||
IndexConditional(joint.keys(), joint.size() - marginal.size()), Potentials(
|
||||
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) {
|
||||
assert(nrFrontals() == 1);
|
||||
// assert(nrFrontals() == 1);
|
||||
if (ISDEBUG("DiscreteConditional::DiscreteConditional")) cout
|
||||
<< (firstFrontalKey()) << endl;
|
||||
<< (firstFrontalKey()) << endl; //TODO Print all keys
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
|
@ -75,10 +75,47 @@ namespace gtsam {
|
|||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::solveInPlace(Values& values) const {
|
||||
assert(nrFrontals() == 1);
|
||||
Index j = (firstFrontalKey());
|
||||
size_t mpe = solve(values); // Solve for variable
|
||||
values[j] = mpe; // store result in partial solution
|
||||
// assert(nrFrontals() == 1);
|
||||
// Index j = (firstFrontalKey());
|
||||
// size_t mpe = solve(values); // Solve for variable
|
||||
// values[j] = mpe; // store result in partial solution
|
||||
|
||||
// TODO: is this really the fastest way? I think it is.
|
||||
ADT pFS = choose(values); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
Values mpe;
|
||||
Values frontalVals;
|
||||
BOOST_FOREACH(Index j, frontals()) {
|
||||
frontalVals[j] = 0;
|
||||
}
|
||||
double maxP = 0;
|
||||
|
||||
while (1) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = frontalVals;
|
||||
}
|
||||
|
||||
size_t j = 0;
|
||||
for (j = 0; j < nrFrontals(); j++) {
|
||||
Index idx = frontals()[j];
|
||||
frontalVals[idx]++;
|
||||
if (frontalVals[idx] < cardinality(idx))
|
||||
break;
|
||||
//Wrap condition
|
||||
frontalVals[idx] = 0;
|
||||
}
|
||||
if (j == nrFrontals())
|
||||
break;
|
||||
}
|
||||
|
||||
//set values (inPlace) to mpe
|
||||
BOOST_FOREACH(Index j, frontals()) {
|
||||
values[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
|
|
|
@ -98,6 +98,14 @@ namespace gtsam {
|
|||
|
||||
virtual operator DecisionTreeFactor() const = 0;
|
||||
|
||||
/**
|
||||
* Permutes the factor, but for efficiency requires the permutation
|
||||
* to already be inverted.
|
||||
*/
|
||||
virtual void permuteWithInverse(const Permutation& inversePermutation){
|
||||
IndexFactor::permuteWithInverse(inversePermutation);
|
||||
}
|
||||
|
||||
/// @}
|
||||
};
|
||||
// DiscreteFactor
|
||||
|
|
|
@ -24,69 +24,71 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
// Explicitly instantiate so we don't have to include everywhere
|
||||
template class FactorGraph<DiscreteFactor> ;
|
||||
template class EliminationTree<DiscreteFactor> ;
|
||||
// Explicitly instantiate so we don't have to include everywhere
|
||||
template class FactorGraph<DiscreteFactor> ;
|
||||
template class EliminationTree<DiscreteFactor> ;
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactorGraph::DiscreteFactorGraph() {
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactorGraph::DiscreteFactorGraph() {
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactorGraph::DiscreteFactorGraph(
|
||||
const BayesNet<DiscreteConditional>& bayesNet) :
|
||||
FactorGraph<DiscreteFactor>(bayesNet) {
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
DiscreteFactorGraph::DiscreteFactorGraph(
|
||||
const BayesNet<DiscreteConditional>& bayesNet) :
|
||||
FactorGraph<DiscreteFactor>(bayesNet) {
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
FastSet<Index> DiscreteFactorGraph::keys() const {
|
||||
FastSet<Index> keys;
|
||||
BOOST_FOREACH(const sharedFactor& factor, *this)
|
||||
if (factor) keys.insert(factor->begin(), factor->end());
|
||||
return keys;
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
FastSet<Index> DiscreteFactorGraph::keys() const {
|
||||
FastSet<Index> keys;
|
||||
BOOST_FOREACH(const sharedFactor& factor, *this)
|
||||
if (factor) keys.insert(factor->begin(), factor->end());
|
||||
return keys;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||
DecisionTreeFactor result;
|
||||
BOOST_FOREACH(const sharedFactor& factor, *this)
|
||||
if (factor) result = (*factor) * result;
|
||||
return result;
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||
DecisionTreeFactor result;
|
||||
BOOST_FOREACH(const sharedFactor& factor, *this)
|
||||
if (factor) result = (*factor) * result;
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteFactorGraph::operator()(
|
||||
const DiscreteFactor::Values &values) const {
|
||||
double product = 1.0;
|
||||
BOOST_FOREACH( const sharedFactor& factor, factors_ )
|
||||
product *= (*factor)(values);
|
||||
return product;
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
double DiscreteFactorGraph::operator()(
|
||||
const DiscreteFactor::Values &values) const {
|
||||
double product = 1.0;
|
||||
BOOST_FOREACH( const sharedFactor& factor, factors_ )
|
||||
product *= (*factor)(values);
|
||||
return product;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const FactorGraph<DiscreteFactor>& factors, size_t num) {
|
||||
/* ************************************************************************* */
|
||||
pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const FactorGraph<DiscreteFactor>& factors, size_t num) {
|
||||
|
||||
// PRODUCT: multiply all factors
|
||||
tic(1, "product");
|
||||
DecisionTreeFactor product;
|
||||
BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors)
|
||||
product = (*factor) * product;
|
||||
toc(1, "product");
|
||||
// PRODUCT: multiply all factors
|
||||
tic(1, "product");
|
||||
DecisionTreeFactor product;
|
||||
BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors){
|
||||
product = (*factor) * product;
|
||||
}
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
tic(2, "sum");
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(num);
|
||||
toc(2, "sum");
|
||||
toc(1, "product");
|
||||
|
||||
// now divide product/sum to get conditional
|
||||
tic(3, "divide");
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
||||
toc(3, "divide");
|
||||
tictoc_finishedIteration();
|
||||
// sum out frontals, this is the factor on the separator
|
||||
tic(2, "sum");
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(num);
|
||||
toc(2, "sum");
|
||||
|
||||
return make_pair(cond, sum);
|
||||
}
|
||||
// now divide product/sum to get conditional
|
||||
tic(3, "divide");
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
||||
toc(3, "divide");
|
||||
tictoc_finishedIteration();
|
||||
|
||||
return make_pair(cond, sum);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace
|
||||
|
|
|
@ -86,7 +86,6 @@ public:
|
|||
if (factors_[i] != NULL) factors_[i]->print(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
// DiscreteFactorGraph
|
||||
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/inference/GenericMultifrontalSolver.h>
|
||||
#include <gtsam/linear/GaussianBayesTree.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -35,27 +37,43 @@ namespace gtsam {
|
|||
|
||||
public:
|
||||
|
||||
/** Construct a marginals class.
|
||||
* @param graph The factor graph defining the full joint density on all variables.
|
||||
*/
|
||||
DiscreteMarginals(const DiscreteFactorGraph& graph) {
|
||||
}
|
||||
/** Construct a marginals class.
|
||||
* @param graph The factor graph defining the full joint density on all variables.
|
||||
*/
|
||||
DiscreteMarginals(const DiscreteFactorGraph& graph) {
|
||||
typedef JunctionTree<DiscreteFactorGraph> DiscreteJT;
|
||||
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
|
||||
bayesTree_ = *solver.eliminate(&EliminateDiscrete);
|
||||
//bayesTree_.print();
|
||||
cout << "Discrete Marginal Constructor Finished" << endl;
|
||||
}
|
||||
|
||||
/** print */
|
||||
void print(const std::string& str = "DiscreteMarginals: ") const {
|
||||
}
|
||||
/** Compute the marginal of a single variable */
|
||||
DiscreteFactor::shared_ptr operator()(Index variable) const {
|
||||
// Compute marginal
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
marginalFactor = bayesTree_.marginalFactor(variable, &EliminateDiscrete);
|
||||
return marginalFactor;
|
||||
}
|
||||
|
||||
/** Compute the marginal of a single variable */
|
||||
DiscreteFactor::shared_ptr operator()(Index variable) const {
|
||||
DiscreteFactor::shared_ptr p;
|
||||
return p;
|
||||
}
|
||||
/** Compute the marginal of a single variable
|
||||
* @param KEY DiscreteKey of the Variable
|
||||
* @return Vector of marginal probabilities
|
||||
*/
|
||||
Vector marginalProbabilities(const DiscreteKey& key) const {
|
||||
// Compute marginal
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
marginalFactor = bayesTree_.marginalFactor(key.first, &EliminateDiscrete);
|
||||
|
||||
/** Compute the marginal of a single variable */
|
||||
Vector marginalProbabilities(Index variable) const {
|
||||
Vector v;
|
||||
return v;
|
||||
}
|
||||
//Create result
|
||||
Vector vResult(key.second);
|
||||
for (size_t state = 0; state < key.second ; ++ state) {
|
||||
DiscreteFactor::Values values;
|
||||
values[key.first] = state;
|
||||
vResult(state) = (*marginalFactor)(values);
|
||||
}
|
||||
return vResult;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
|
|
@ -52,6 +52,24 @@ namespace gtsam {
|
|||
|
||||
return solution;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
Vector DiscreteSequentialSolver::marginalProbabilities(
|
||||
const DiscreteKey& key) const {
|
||||
// Compute marginal
|
||||
DiscreteFactor::shared_ptr marginalFactor;
|
||||
marginalFactor = Base::marginalFactor(key.first, &EliminateDiscrete);
|
||||
|
||||
//Create result
|
||||
Vector vResult(key.second);
|
||||
for (size_t state = 0; state < key.second; ++state) {
|
||||
DiscreteFactor::Values values;
|
||||
values[key.first] = state;
|
||||
vResult(state) = (*marginalFactor)(values);
|
||||
}
|
||||
return vResult;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
}
|
||||
|
|
|
@ -74,7 +74,6 @@ namespace gtsam {
|
|||
return Base::eliminate(&EliminateDiscrete);
|
||||
}
|
||||
|
||||
#ifdef BROKEN
|
||||
/**
|
||||
* Compute the marginal joint over a set of variables, by integrating out
|
||||
* all of the other variables. This function returns the result as a factor
|
||||
|
@ -94,7 +93,13 @@ namespace gtsam {
|
|||
DiscreteFactor::shared_ptr marginalFactor(Index j) const {
|
||||
return Base::marginalFactor(j, &EliminateDiscrete);
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Compute the marginal density over a variable, by integrating out
|
||||
* all of the other variables. This function returns the result as a
|
||||
* Vector of the probability values.
|
||||
*/
|
||||
Vector marginalProbabilities(const DiscreteKey& key) const;
|
||||
|
||||
/**
|
||||
* Compute the MPE solution of the DiscreteFactorGraph. This
|
||||
|
|
|
@ -60,5 +60,28 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void Potentials::permute(const Permutation& permutation) {
|
||||
// Permute the _cardinalities (TODO: Inefficient Consider Improving)
|
||||
DiscreteKeys keys;
|
||||
map<Index, Index> ordering;
|
||||
|
||||
// Get the orginal keys from cardinalities_
|
||||
BOOST_FOREACH(const DiscreteKey& key, cardinalities_)
|
||||
keys & key;
|
||||
|
||||
// Perform Permutation
|
||||
BOOST_FOREACH(DiscreteKey& key, keys) {
|
||||
ordering[key.first] = permutation[key.first];
|
||||
//cout << key.first << " -> " << ordering[key.first] << endl;
|
||||
key.first = ordering[key.first];
|
||||
}
|
||||
|
||||
// Change *this
|
||||
AlgebraicDecisionTree<Index> permuted((*this), ordering);
|
||||
*this = permuted;
|
||||
cardinalities_ = keys.cardinalities();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/base/types.h>
|
||||
#include <gtsam/inference/Permutation.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <set>
|
||||
|
@ -68,6 +69,13 @@ namespace gtsam {
|
|||
|
||||
size_t cardinality(Index j) const { return cardinalities_.at(j);}
|
||||
|
||||
/*
|
||||
* @brief Permutes the keys in Potentials
|
||||
*
|
||||
* This permutes the Indices and performs necessary re-ordering of ADD.
|
||||
*/
|
||||
void permute(const Permutation& perm);
|
||||
|
||||
}; // Potentials
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
Loading…
Reference in New Issue