diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 537bb3e60..a29a5683a 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -127,6 +127,11 @@ namespace gtsam { */ shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; + void permuteWithInverse(const Permutation& inversePermutation){ + DiscreteFactor::permuteWithInverse(inversePermutation); + Potentials::permute(inversePermutation); + } + /// @} }; // DecisionTreeFactor diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 73649dce6..ca8f13ca3 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -46,9 +46,9 @@ namespace gtsam { const DecisionTreeFactor& marginal) : IndexConditional(joint.keys(), joint.size() - marginal.size()), Potentials( ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) { - assert(nrFrontals() == 1); +// assert(nrFrontals() == 1); if (ISDEBUG("DiscreteConditional::DiscreteConditional")) cout - << (firstFrontalKey()) << endl; + << (firstFrontalKey()) << endl; //TODO Print all keys } /* ******************************************************************************** */ @@ -75,10 +75,47 @@ namespace gtsam { /* ******************************************************************************** */ void DiscreteConditional::solveInPlace(Values& values) const { - assert(nrFrontals() == 1); - Index j = (firstFrontalKey()); - size_t mpe = solve(values); // Solve for variable - values[j] = mpe; // store result in partial solution +// assert(nrFrontals() == 1); +// Index j = (firstFrontalKey()); +// size_t mpe = solve(values); // Solve for variable +// values[j] = mpe; // store result in partial solution + + // TODO: is this really the fastest way? I think it is. + ADT pFS = choose(values); // P(F|S=parentsValues) + + // Initialize + Values mpe; + Values frontalVals; + BOOST_FOREACH(Index j, frontals()) { + frontalVals[j] = 0; + } + double maxP = 0; + + while (1) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + + size_t j = 0; + for (j = 0; j < nrFrontals(); j++) { + Index idx = frontals()[j]; + frontalVals[idx]++; + if (frontalVals[idx] < cardinality(idx)) + break; + //Wrap condition + frontalVals[idx] = 0; + } + if (j == nrFrontals()) + break; + } + + //set values (inPlace) to mpe + BOOST_FOREACH(Index j, frontals()) { + values[j] = mpe[j]; + } } /* ******************************************************************************** */ diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 3789389e4..32e79568c 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -98,6 +98,14 @@ namespace gtsam { virtual operator DecisionTreeFactor() const = 0; + /** + * Permutes the factor, but for efficiency requires the permutation + * to already be inverted. + */ + virtual void permuteWithInverse(const Permutation& inversePermutation){ + IndexFactor::permuteWithInverse(inversePermutation); + } + /// @} }; // DiscreteFactor diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index f5fc7533d..0d5e01c29 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -24,69 +24,71 @@ namespace gtsam { -// Explicitly instantiate so we don't have to include everywhere -template class FactorGraph ; -template class EliminationTree ; + // Explicitly instantiate so we don't have to include everywhere + template class FactorGraph ; + template class EliminationTree ; -/* ************************************************************************* */ -DiscreteFactorGraph::DiscreteFactorGraph() { -} + /* ************************************************************************* */ + DiscreteFactorGraph::DiscreteFactorGraph() { + } -/* ************************************************************************* */ -DiscreteFactorGraph::DiscreteFactorGraph( - const BayesNet& bayesNet) : - FactorGraph(bayesNet) { -} + /* ************************************************************************* */ + DiscreteFactorGraph::DiscreteFactorGraph( + const BayesNet& bayesNet) : + FactorGraph(bayesNet) { + } -/* ************************************************************************* */ -FastSet DiscreteFactorGraph::keys() const { - FastSet keys; - BOOST_FOREACH(const sharedFactor& factor, *this) - if (factor) keys.insert(factor->begin(), factor->end()); - return keys; -} + /* ************************************************************************* */ + FastSet DiscreteFactorGraph::keys() const { + FastSet keys; + BOOST_FOREACH(const sharedFactor& factor, *this) + if (factor) keys.insert(factor->begin(), factor->end()); + return keys; + } -/* ************************************************************************* */ -DecisionTreeFactor DiscreteFactorGraph::product() const { - DecisionTreeFactor result; - BOOST_FOREACH(const sharedFactor& factor, *this) - if (factor) result = (*factor) * result; - return result; -} + /* ************************************************************************* */ + DecisionTreeFactor DiscreteFactorGraph::product() const { + DecisionTreeFactor result; + BOOST_FOREACH(const sharedFactor& factor, *this) + if (factor) result = (*factor) * result; + return result; + } -/* ************************************************************************* */ -double DiscreteFactorGraph::operator()( - const DiscreteFactor::Values &values) const { - double product = 1.0; - BOOST_FOREACH( const sharedFactor& factor, factors_ ) - product *= (*factor)(values); - return product; -} + /* ************************************************************************* */ + double DiscreteFactorGraph::operator()( + const DiscreteFactor::Values &values) const { + double product = 1.0; + BOOST_FOREACH( const sharedFactor& factor, factors_ ) + product *= (*factor)(values); + return product; + } -/* ************************************************************************* */ -pair // -EliminateDiscrete(const FactorGraph& factors, size_t num) { + /* ************************************************************************* */ + pair // + EliminateDiscrete(const FactorGraph& factors, size_t num) { - // PRODUCT: multiply all factors - tic(1, "product"); - DecisionTreeFactor product; - BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors) - product = (*factor) * product; - toc(1, "product"); + // PRODUCT: multiply all factors + tic(1, "product"); + DecisionTreeFactor product; + BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors){ + product = (*factor) * product; + } - // sum out frontals, this is the factor on the separator - tic(2, "sum"); - DecisionTreeFactor::shared_ptr sum = product.sum(num); - toc(2, "sum"); + toc(1, "product"); - // now divide product/sum to get conditional - tic(3, "divide"); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); - toc(3, "divide"); - tictoc_finishedIteration(); + // sum out frontals, this is the factor on the separator + tic(2, "sum"); + DecisionTreeFactor::shared_ptr sum = product.sum(num); + toc(2, "sum"); - return make_pair(cond, sum); -} + // now divide product/sum to get conditional + tic(3, "divide"); + DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); + toc(3, "divide"); + tictoc_finishedIteration(); + + return make_pair(cond, sum); + } /* ************************************************************************* */ } // namespace diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index eefdfaf1d..d96a3049a 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -86,7 +86,6 @@ public: if (factors_[i] != NULL) factors_[i]->print(ss.str()); } } - }; // DiscreteFactorGraph diff --git a/gtsam/discrete/DiscreteMarginals.h b/gtsam/discrete/DiscreteMarginals.h index f877a0128..99176e58c 100644 --- a/gtsam/discrete/DiscreteMarginals.h +++ b/gtsam/discrete/DiscreteMarginals.h @@ -21,6 +21,8 @@ #pragma once #include +#include +#include namespace gtsam { @@ -35,27 +37,43 @@ namespace gtsam { public: - /** Construct a marginals class. - * @param graph The factor graph defining the full joint density on all variables. - */ - DiscreteMarginals(const DiscreteFactorGraph& graph) { - } + /** Construct a marginals class. + * @param graph The factor graph defining the full joint density on all variables. + */ + DiscreteMarginals(const DiscreteFactorGraph& graph) { + typedef JunctionTree DiscreteJT; + GenericMultifrontalSolver solver(graph); + bayesTree_ = *solver.eliminate(&EliminateDiscrete); + //bayesTree_.print(); + cout << "Discrete Marginal Constructor Finished" << endl; + } - /** print */ - void print(const std::string& str = "DiscreteMarginals: ") const { - } + /** Compute the marginal of a single variable */ + DiscreteFactor::shared_ptr operator()(Index variable) const { + // Compute marginal + DiscreteFactor::shared_ptr marginalFactor; + marginalFactor = bayesTree_.marginalFactor(variable, &EliminateDiscrete); + return marginalFactor; + } - /** Compute the marginal of a single variable */ - DiscreteFactor::shared_ptr operator()(Index variable) const { - DiscreteFactor::shared_ptr p; - return p; - } + /** Compute the marginal of a single variable + * @param KEY DiscreteKey of the Variable + * @return Vector of marginal probabilities + */ + Vector marginalProbabilities(const DiscreteKey& key) const { + // Compute marginal + DiscreteFactor::shared_ptr marginalFactor; + marginalFactor = bayesTree_.marginalFactor(key.first, &EliminateDiscrete); - /** Compute the marginal of a single variable */ - Vector marginalProbabilities(Index variable) const { - Vector v; - return v; - } + //Create result + Vector vResult(key.second); + for (size_t state = 0; state < key.second ; ++ state) { + DiscreteFactor::Values values; + values[key.first] = state; + vResult(state) = (*marginalFactor)(values); + } + return vResult; + } }; diff --git a/gtsam/discrete/DiscreteSequentialSolver.cpp b/gtsam/discrete/DiscreteSequentialSolver.cpp index 1ca00875a..e4bc502c8 100644 --- a/gtsam/discrete/DiscreteSequentialSolver.cpp +++ b/gtsam/discrete/DiscreteSequentialSolver.cpp @@ -52,6 +52,24 @@ namespace gtsam { return solution; } + + /* ************************************************************************* */ + Vector DiscreteSequentialSolver::marginalProbabilities( + const DiscreteKey& key) const { + // Compute marginal + DiscreteFactor::shared_ptr marginalFactor; + marginalFactor = Base::marginalFactor(key.first, &EliminateDiscrete); + + //Create result + Vector vResult(key.second); + for (size_t state = 0; state < key.second; ++state) { + DiscreteFactor::Values values; + values[key.first] = state; + vResult(state) = (*marginalFactor)(values); + } + return vResult; + } + /* ************************************************************************* */ } diff --git a/gtsam/discrete/DiscreteSequentialSolver.h b/gtsam/discrete/DiscreteSequentialSolver.h index c453f5b96..6adb32ef6 100644 --- a/gtsam/discrete/DiscreteSequentialSolver.h +++ b/gtsam/discrete/DiscreteSequentialSolver.h @@ -74,7 +74,6 @@ namespace gtsam { return Base::eliminate(&EliminateDiscrete); } -#ifdef BROKEN /** * Compute the marginal joint over a set of variables, by integrating out * all of the other variables. This function returns the result as a factor @@ -94,7 +93,13 @@ namespace gtsam { DiscreteFactor::shared_ptr marginalFactor(Index j) const { return Base::marginalFactor(j, &EliminateDiscrete); } -#endif + + /** + * Compute the marginal density over a variable, by integrating out + * all of the other variables. This function returns the result as a + * Vector of the probability values. + */ + Vector marginalProbabilities(const DiscreteKey& key) const; /** * Compute the MPE solution of the DiscreteFactorGraph. This diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp index ac6ecde10..e6f07ccac 100644 --- a/gtsam/discrete/Potentials.cpp +++ b/gtsam/discrete/Potentials.cpp @@ -60,5 +60,28 @@ namespace gtsam { } /* ************************************************************************* */ + void Potentials::permute(const Permutation& permutation) { + // Permute the _cardinalities (TODO: Inefficient Consider Improving) + DiscreteKeys keys; + map ordering; + + // Get the orginal keys from cardinalities_ + BOOST_FOREACH(const DiscreteKey& key, cardinalities_) + keys & key; + + // Perform Permutation + BOOST_FOREACH(DiscreteKey& key, keys) { + ordering[key.first] = permutation[key.first]; + //cout << key.first << " -> " << ordering[key.first] << endl; + key.first = ordering[key.first]; + } + + // Change *this + AlgebraicDecisionTree permuted((*this), ordering); + *this = permuted; + cardinalities_ = keys.cardinalities(); + } + + /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/discrete/Potentials.h b/gtsam/discrete/Potentials.h index 8a8c2e3bc..d8c5d9c3b 100644 --- a/gtsam/discrete/Potentials.h +++ b/gtsam/discrete/Potentials.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -68,6 +69,13 @@ namespace gtsam { size_t cardinality(Index j) const { return cardinalities_.at(j);} + /* + * @brief Permutes the keys in Potentials + * + * This permutes the Indices and performs necessary re-ordering of ADD. + */ + void permute(const Permutation& perm); + }; // Potentials } // namespace gtsam