Merging all Discrete changes
parent
e6a0663540
commit
c31ef559f8
|
@ -127,6 +127,11 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
|
shared_ptr combine(size_t nrFrontals, ADT::Binary op) const;
|
||||||
|
|
||||||
|
void permuteWithInverse(const Permutation& inversePermutation){
|
||||||
|
DiscreteFactor::permuteWithInverse(inversePermutation);
|
||||||
|
Potentials::permute(inversePermutation);
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
// DecisionTreeFactor
|
// DecisionTreeFactor
|
||||||
|
|
|
@ -46,9 +46,9 @@ namespace gtsam {
|
||||||
const DecisionTreeFactor& marginal) :
|
const DecisionTreeFactor& marginal) :
|
||||||
IndexConditional(joint.keys(), joint.size() - marginal.size()), Potentials(
|
IndexConditional(joint.keys(), joint.size() - marginal.size()), Potentials(
|
||||||
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) {
|
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) {
|
||||||
assert(nrFrontals() == 1);
|
// assert(nrFrontals() == 1);
|
||||||
if (ISDEBUG("DiscreteConditional::DiscreteConditional")) cout
|
if (ISDEBUG("DiscreteConditional::DiscreteConditional")) cout
|
||||||
<< (firstFrontalKey()) << endl;
|
<< (firstFrontalKey()) << endl; //TODO Print all keys
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
|
@ -75,10 +75,47 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
void DiscreteConditional::solveInPlace(Values& values) const {
|
void DiscreteConditional::solveInPlace(Values& values) const {
|
||||||
assert(nrFrontals() == 1);
|
// assert(nrFrontals() == 1);
|
||||||
Index j = (firstFrontalKey());
|
// Index j = (firstFrontalKey());
|
||||||
size_t mpe = solve(values); // Solve for variable
|
// size_t mpe = solve(values); // Solve for variable
|
||||||
values[j] = mpe; // store result in partial solution
|
// values[j] = mpe; // store result in partial solution
|
||||||
|
|
||||||
|
// TODO: is this really the fastest way? I think it is.
|
||||||
|
ADT pFS = choose(values); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
Values mpe;
|
||||||
|
Values frontalVals;
|
||||||
|
BOOST_FOREACH(Index j, frontals()) {
|
||||||
|
frontalVals[j] = 0;
|
||||||
|
}
|
||||||
|
double maxP = 0;
|
||||||
|
|
||||||
|
while (1) {
|
||||||
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
|
// Update MPE solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = frontalVals;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t j = 0;
|
||||||
|
for (j = 0; j < nrFrontals(); j++) {
|
||||||
|
Index idx = frontals()[j];
|
||||||
|
frontalVals[idx]++;
|
||||||
|
if (frontalVals[idx] < cardinality(idx))
|
||||||
|
break;
|
||||||
|
//Wrap condition
|
||||||
|
frontalVals[idx] = 0;
|
||||||
|
}
|
||||||
|
if (j == nrFrontals())
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
//set values (inPlace) to mpe
|
||||||
|
BOOST_FOREACH(Index j, frontals()) {
|
||||||
|
values[j] = mpe[j];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
|
|
|
@ -98,6 +98,14 @@ namespace gtsam {
|
||||||
|
|
||||||
virtual operator DecisionTreeFactor() const = 0;
|
virtual operator DecisionTreeFactor() const = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Permutes the factor, but for efficiency requires the permutation
|
||||||
|
* to already be inverted.
|
||||||
|
*/
|
||||||
|
virtual void permuteWithInverse(const Permutation& inversePermutation){
|
||||||
|
IndexFactor::permuteWithInverse(inversePermutation);
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
// DiscreteFactor
|
// DiscreteFactor
|
||||||
|
|
|
@ -24,69 +24,71 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Explicitly instantiate so we don't have to include everywhere
|
// Explicitly instantiate so we don't have to include everywhere
|
||||||
template class FactorGraph<DiscreteFactor> ;
|
template class FactorGraph<DiscreteFactor> ;
|
||||||
template class EliminationTree<DiscreteFactor> ;
|
template class EliminationTree<DiscreteFactor> ;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteFactorGraph::DiscreteFactorGraph() {
|
DiscreteFactorGraph::DiscreteFactorGraph() {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteFactorGraph::DiscreteFactorGraph(
|
DiscreteFactorGraph::DiscreteFactorGraph(
|
||||||
const BayesNet<DiscreteConditional>& bayesNet) :
|
const BayesNet<DiscreteConditional>& bayesNet) :
|
||||||
FactorGraph<DiscreteFactor>(bayesNet) {
|
FactorGraph<DiscreteFactor>(bayesNet) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
FastSet<Index> DiscreteFactorGraph::keys() const {
|
FastSet<Index> DiscreteFactorGraph::keys() const {
|
||||||
FastSet<Index> keys;
|
FastSet<Index> keys;
|
||||||
BOOST_FOREACH(const sharedFactor& factor, *this)
|
BOOST_FOREACH(const sharedFactor& factor, *this)
|
||||||
if (factor) keys.insert(factor->begin(), factor->end());
|
if (factor) keys.insert(factor->begin(), factor->end());
|
||||||
return keys;
|
return keys;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
DecisionTreeFactor DiscreteFactorGraph::product() const {
|
||||||
DecisionTreeFactor result;
|
DecisionTreeFactor result;
|
||||||
BOOST_FOREACH(const sharedFactor& factor, *this)
|
BOOST_FOREACH(const sharedFactor& factor, *this)
|
||||||
if (factor) result = (*factor) * result;
|
if (factor) result = (*factor) * result;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteFactorGraph::operator()(
|
double DiscreteFactorGraph::operator()(
|
||||||
const DiscreteFactor::Values &values) const {
|
const DiscreteFactor::Values &values) const {
|
||||||
double product = 1.0;
|
double product = 1.0;
|
||||||
BOOST_FOREACH( const sharedFactor& factor, factors_ )
|
BOOST_FOREACH( const sharedFactor& factor, factors_ )
|
||||||
product *= (*factor)(values);
|
product *= (*factor)(values);
|
||||||
return product;
|
return product;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||||
EliminateDiscrete(const FactorGraph<DiscreteFactor>& factors, size_t num) {
|
EliminateDiscrete(const FactorGraph<DiscreteFactor>& factors, size_t num) {
|
||||||
|
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
tic(1, "product");
|
tic(1, "product");
|
||||||
DecisionTreeFactor product;
|
DecisionTreeFactor product;
|
||||||
BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors)
|
BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors){
|
||||||
product = (*factor) * product;
|
product = (*factor) * product;
|
||||||
toc(1, "product");
|
}
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
toc(1, "product");
|
||||||
tic(2, "sum");
|
|
||||||
DecisionTreeFactor::shared_ptr sum = product.sum(num);
|
|
||||||
toc(2, "sum");
|
|
||||||
|
|
||||||
// now divide product/sum to get conditional
|
// sum out frontals, this is the factor on the separator
|
||||||
tic(3, "divide");
|
tic(2, "sum");
|
||||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
DecisionTreeFactor::shared_ptr sum = product.sum(num);
|
||||||
toc(3, "divide");
|
toc(2, "sum");
|
||||||
tictoc_finishedIteration();
|
|
||||||
|
|
||||||
return make_pair(cond, sum);
|
// now divide product/sum to get conditional
|
||||||
}
|
tic(3, "divide");
|
||||||
|
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
||||||
|
toc(3, "divide");
|
||||||
|
tictoc_finishedIteration();
|
||||||
|
|
||||||
|
return make_pair(cond, sum);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -86,7 +86,6 @@ public:
|
||||||
if (factors_[i] != NULL) factors_[i]->print(ss.str());
|
if (factors_[i] != NULL) factors_[i]->print(ss.str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
// DiscreteFactorGraph
|
// DiscreteFactorGraph
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
#include <gtsam/inference/GenericMultifrontalSolver.h>
|
||||||
|
#include <gtsam/linear/GaussianBayesTree.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -35,27 +37,43 @@ namespace gtsam {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** Construct a marginals class.
|
/** Construct a marginals class.
|
||||||
* @param graph The factor graph defining the full joint density on all variables.
|
* @param graph The factor graph defining the full joint density on all variables.
|
||||||
*/
|
*/
|
||||||
DiscreteMarginals(const DiscreteFactorGraph& graph) {
|
DiscreteMarginals(const DiscreteFactorGraph& graph) {
|
||||||
}
|
typedef JunctionTree<DiscreteFactorGraph> DiscreteJT;
|
||||||
|
GenericMultifrontalSolver<DiscreteFactor, DiscreteJT> solver(graph);
|
||||||
|
bayesTree_ = *solver.eliminate(&EliminateDiscrete);
|
||||||
|
//bayesTree_.print();
|
||||||
|
cout << "Discrete Marginal Constructor Finished" << endl;
|
||||||
|
}
|
||||||
|
|
||||||
/** print */
|
/** Compute the marginal of a single variable */
|
||||||
void print(const std::string& str = "DiscreteMarginals: ") const {
|
DiscreteFactor::shared_ptr operator()(Index variable) const {
|
||||||
}
|
// Compute marginal
|
||||||
|
DiscreteFactor::shared_ptr marginalFactor;
|
||||||
|
marginalFactor = bayesTree_.marginalFactor(variable, &EliminateDiscrete);
|
||||||
|
return marginalFactor;
|
||||||
|
}
|
||||||
|
|
||||||
/** Compute the marginal of a single variable */
|
/** Compute the marginal of a single variable
|
||||||
DiscreteFactor::shared_ptr operator()(Index variable) const {
|
* @param KEY DiscreteKey of the Variable
|
||||||
DiscreteFactor::shared_ptr p;
|
* @return Vector of marginal probabilities
|
||||||
return p;
|
*/
|
||||||
}
|
Vector marginalProbabilities(const DiscreteKey& key) const {
|
||||||
|
// Compute marginal
|
||||||
|
DiscreteFactor::shared_ptr marginalFactor;
|
||||||
|
marginalFactor = bayesTree_.marginalFactor(key.first, &EliminateDiscrete);
|
||||||
|
|
||||||
/** Compute the marginal of a single variable */
|
//Create result
|
||||||
Vector marginalProbabilities(Index variable) const {
|
Vector vResult(key.second);
|
||||||
Vector v;
|
for (size_t state = 0; state < key.second ; ++ state) {
|
||||||
return v;
|
DiscreteFactor::Values values;
|
||||||
}
|
values[key.first] = state;
|
||||||
|
vResult(state) = (*marginalFactor)(values);
|
||||||
|
}
|
||||||
|
return vResult;
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -52,6 +52,24 @@ namespace gtsam {
|
||||||
|
|
||||||
return solution;
|
return solution;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
Vector DiscreteSequentialSolver::marginalProbabilities(
|
||||||
|
const DiscreteKey& key) const {
|
||||||
|
// Compute marginal
|
||||||
|
DiscreteFactor::shared_ptr marginalFactor;
|
||||||
|
marginalFactor = Base::marginalFactor(key.first, &EliminateDiscrete);
|
||||||
|
|
||||||
|
//Create result
|
||||||
|
Vector vResult(key.second);
|
||||||
|
for (size_t state = 0; state < key.second; ++state) {
|
||||||
|
DiscreteFactor::Values values;
|
||||||
|
values[key.first] = state;
|
||||||
|
vResult(state) = (*marginalFactor)(values);
|
||||||
|
}
|
||||||
|
return vResult;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,7 +74,6 @@ namespace gtsam {
|
||||||
return Base::eliminate(&EliminateDiscrete);
|
return Base::eliminate(&EliminateDiscrete);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef BROKEN
|
|
||||||
/**
|
/**
|
||||||
* Compute the marginal joint over a set of variables, by integrating out
|
* Compute the marginal joint over a set of variables, by integrating out
|
||||||
* all of the other variables. This function returns the result as a factor
|
* all of the other variables. This function returns the result as a factor
|
||||||
|
@ -94,7 +93,13 @@ namespace gtsam {
|
||||||
DiscreteFactor::shared_ptr marginalFactor(Index j) const {
|
DiscreteFactor::shared_ptr marginalFactor(Index j) const {
|
||||||
return Base::marginalFactor(j, &EliminateDiscrete);
|
return Base::marginalFactor(j, &EliminateDiscrete);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
/**
|
||||||
|
* Compute the marginal density over a variable, by integrating out
|
||||||
|
* all of the other variables. This function returns the result as a
|
||||||
|
* Vector of the probability values.
|
||||||
|
*/
|
||||||
|
Vector marginalProbabilities(const DiscreteKey& key) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute the MPE solution of the DiscreteFactorGraph. This
|
* Compute the MPE solution of the DiscreteFactorGraph. This
|
||||||
|
|
|
@ -60,5 +60,28 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
void Potentials::permute(const Permutation& permutation) {
|
||||||
|
// Permute the _cardinalities (TODO: Inefficient Consider Improving)
|
||||||
|
DiscreteKeys keys;
|
||||||
|
map<Index, Index> ordering;
|
||||||
|
|
||||||
|
// Get the orginal keys from cardinalities_
|
||||||
|
BOOST_FOREACH(const DiscreteKey& key, cardinalities_)
|
||||||
|
keys & key;
|
||||||
|
|
||||||
|
// Perform Permutation
|
||||||
|
BOOST_FOREACH(DiscreteKey& key, keys) {
|
||||||
|
ordering[key.first] = permutation[key.first];
|
||||||
|
//cout << key.first << " -> " << ordering[key.first] << endl;
|
||||||
|
key.first = ordering[key.first];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change *this
|
||||||
|
AlgebraicDecisionTree<Index> permuted((*this), ordering);
|
||||||
|
*this = permuted;
|
||||||
|
cardinalities_ = keys.cardinalities();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/base/types.h>
|
#include <gtsam/base/types.h>
|
||||||
|
#include <gtsam/inference/Permutation.h>
|
||||||
|
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
@ -68,6 +69,13 @@ namespace gtsam {
|
||||||
|
|
||||||
size_t cardinality(Index j) const { return cardinalities_.at(j);}
|
size_t cardinality(Index j) const { return cardinalities_.at(j);}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* @brief Permutes the keys in Potentials
|
||||||
|
*
|
||||||
|
* This permutes the Indices and performs necessary re-ordering of ADD.
|
||||||
|
*/
|
||||||
|
void permute(const Permutation& perm);
|
||||||
|
|
||||||
}; // Potentials
|
}; // Potentials
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
Loading…
Reference in New Issue