Merge pull request #1050 from borglab/feature/betterMPE

release/4.3a0
Frank Dellaert 2022-01-22 14:49:00 -05:00 committed by GitHub
commit b441eea976
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 738 additions and 346 deletions

View File

@ -53,10 +53,9 @@ int main(int argc, char **argv) {
// Create solver and eliminate // Create solver and eliminate
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
// solve // solve
auto mpe = chordal->optimize(); auto mpe = fg.optimize();
GTSAM_PRINT(mpe); GTSAM_PRINT(mpe);
// We can also build a Bayes tree (directed junction tree). // We can also build a Bayes tree (directed junction tree).
@ -69,14 +68,14 @@ int main(int argc, char **argv) {
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); auto mpe2 = fg.optimize();
auto mpe2 = chordal2->optimize();
GTSAM_PRINT(mpe2); GTSAM_PRINT(mpe2);
// We can also sample from it // We can also sample from it
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < 10; i++) {
auto sample = chordal2->sample(); auto sample = chordal->sample();
GTSAM_PRINT(sample); GTSAM_PRINT(sample);
} }
return 0; return 0;

View File

@ -85,7 +85,7 @@ int main(int argc, char **argv) {
} }
// "Most Probable Explanation", i.e., configuration with largest value // "Most Probable Explanation", i.e., configuration with largest value
auto mpe = graph.eliminateSequential()->optimize(); auto mpe = graph.optimize();
cout << "\nMost Probable Explanation (MPE):" << endl; cout << "\nMost Probable Explanation (MPE):" << endl;
print(mpe); print(mpe);
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
graph.add(Cloudy, "1 0"); graph.add(Cloudy, "1 0");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto mpe_with_evidence = graph.optimize();
auto mpe_with_evidence = chordal->optimize();
cout << "\nMPE given C=0:" << endl; cout << "\nMPE given C=0:" << endl;
print(mpe_with_evidence); print(mpe_with_evidence);
@ -110,7 +109,8 @@ int main(int argc, char **argv) {
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1] cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
<< endl; << endl;
// We can also sample from it // We can also sample from the eliminated graph
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < 10; i++) {
auto sample = chordal->sample(); auto sample = chordal->sample();

View File

@ -59,16 +59,16 @@ int main(int argc, char **argv) {
// Convert to factor graph // Convert to factor graph
DiscreteFactorGraph factorGraph(hmm); DiscreteFactorGraph factorGraph(hmm);
// Do max-prodcut
auto mpe = factorGraph.optimize();
GTSAM_PRINT(mpe);
// Create solver and eliminate // Create solver and eliminate
// This will create a DAG ordered with arrow of time reversed // This will create a DAG ordered with arrow of time reversed
DiscreteBayesNet::shared_ptr chordal = DiscreteBayesNet::shared_ptr chordal =
factorGraph.eliminateSequential(ordering); factorGraph.eliminateSequential(ordering);
chordal->print("Eliminated"); chordal->print("Eliminated");
// solve
auto mpe = chordal->optimize();
GTSAM_PRINT(mpe);
// We can also sample from it // We can also sample from it
cout << "\n10 samples:" << endl; cout << "\n10 samples:" << endl;
for (size_t k = 0; k < 10; k++) { for (size_t k = 0; k < 10; k++) {

View File

@ -68,9 +68,8 @@ int main(int argc, char** argv) {
<< graph.size() << " factors (Unary+Edge)."; << graph.size() << " factors (Unary+Edge).";
// "Decoding", i.e., configuration with largest value // "Decoding", i.e., configuration with largest value
// We use sequential variable elimination // Uses max-product.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto optimalDecoding = graph.optimize();
auto optimalDecoding = chordal->optimize();
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n"); optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
// "Inference" Computing marginals for each node // "Inference" Computing marginals for each node

View File

@ -61,9 +61,8 @@ int main(int argc, char** argv) {
} }
// "Decoding", i.e., configuration with largest value (MPE) // "Decoding", i.e., configuration with largest value (MPE)
// We use sequential variable elimination // Uses max-product
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto optimalDecoding = graph.optimize();
auto optimalDecoding = chordal->optimize();
GTSAM_PRINT(optimalDecoding); GTSAM_PRINT(optimalDecoding);
// "Inference" Computing marginals // "Inference" Computing marginals

View File

@ -22,6 +22,7 @@
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/format.hpp>
#include <utility> #include <utility>
using namespace std; using namespace std;
@ -65,9 +66,13 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
void DecisionTreeFactor::print(const string& s, void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
cout << s; cout << s;
ADT::print("Potentials:",formatter); cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("Potentials:", formatter);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -127,11 +127,16 @@ namespace gtsam {
return combine(keys, ADT::Ring::add); return combine(keys, ADT::Ring::add);
} }
/// Create new factor by maximizing over all values with the same separator values /// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const { shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max); return combine(nrFrontals, ADT::Ring::max);
} }
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, ADT::Ring::max);
}
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{

View File

@ -43,6 +43,7 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
DiscreteValues DiscreteBayesNet::optimize() const { DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result; DiscreteValues result;
return optimize(result); return optimize(result);
@ -50,10 +51,16 @@ DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const { DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first) // solve each node in turn in topological sort order (parents first)
#ifdef _MSC_VER
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
#else
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
#endif
for (auto conditional : boost::adaptors::reverse(*this)) for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result); conditional->solveInPlace(&result);
return result; return result;
} }
#endif
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const { DiscreteValues DiscreteBayesNet::sample() const {

View File

@ -31,12 +31,12 @@
namespace gtsam { namespace gtsam {
/** A Bayes net made from linear-Discrete densities */ /** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{ {
public: public:
typedef FactorGraph<DiscreteConditional> Base; typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This; typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType; typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
@ -45,7 +45,7 @@ namespace gtsam {
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
/** Construct empty factor graph */ /// Construct empty Bayes net.
DiscreteBayesNet() {} DiscreteBayesNet() {}
/** Construct from iterator over conditionals */ /** Construct from iterator over conditionals */
@ -98,27 +98,6 @@ namespace gtsam {
return evaluate(values); return evaluate(values);
} }
/**
* @brief solve by back-substitution.
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be optimized first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues optimize() const;
/**
* @brief solve by back-substitution, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with optimized value for other variables.
*/
DiscreteValues optimize(DiscreteValues given) const;
/** /**
* @brief do ancestral sampling * @brief do ancestral sampling
* *
@ -152,7 +131,16 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const; const DiscreteFactor::Names& names = {}) const;
///@}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
DiscreteValues GTSAM_DEPRECATED optimize() const;
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
/// @} /// @}
#endif
private: private:
/** Serialization function */ /** Serialization function */

View File

@ -16,26 +16,25 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <boost/make_shared.hpp>
#include <algorithm> #include <algorithm>
#include <boost/make_shared.hpp>
#include <random> #include <random>
#include <set>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector>
#include <utility> #include <utility>
#include <set> #include <vector>
using namespace std; using namespace std;
using std::pair;
using std::stringstream; using std::stringstream;
using std::vector; using std::vector;
using std::pair;
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
cout << endl; cout << endl;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other, bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) { if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
} }
/* ************************************************************************** */ /* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, DiscreteConditional::ADT DiscreteConditional::choose(
const DiscreteValues& given, const DiscreteValues& given, bool forceComplete) const {
bool forceComplete = true) {
// Get the big decision tree with all the levels, and then go down the // Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables. // branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional); DiscreteConditional::ADT adt(*this);
size_t value; size_t value;
for (Key j : conditional.parents()) { for (Key j : parents()) {
try { try {
value = given.at(j); value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller. adt = adt.choose(j, value); // ADT keeps getting smaller.
@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
if (forceComplete) { if (forceComplete) {
given.print("parentsValues: "); given.print("parentsValues: ");
throw runtime_error( throw runtime_error(
"DiscreteConditional::Choose: parent value missing"); "DiscreteConditional::choose: parent value missing");
} }
} }
} }
@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose( DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const { const DiscreteValues& given) const {
ADT adt = Choose(*this, given, false); // P(F|S=given) ADT adt = choose(given, false); // P(F|S=given)
// Collect all keys not in given. // Collect all keys not in given.
DiscreteKeys dKeys; DiscreteKeys dKeys;
@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
} }
/* ******************************************************************************** */ /* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const { size_t parent_value) const {
if (nrFrontals() != 1) if (nrFrontals() != 1)
@ -238,8 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
} }
/* ************************************************************************** */ /* ************************************************************************** */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void DiscreteConditional::solveInPlace(DiscreteValues* values) const { void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize // Initialize
DiscreteValues mpe; DiscreteValues mpe;
@ -248,59 +247,79 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// Get all Possible Configurations // Get all Possible Configurations
const auto allPosbValues = frontalAssignments(); const auto allPosbValues = frontalAssignments();
// Find the MPE // Find the maximum
for (const auto& frontalVals : allPosbValues) { for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better // Update maximum solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
mpe = frontalVals; mpe = frontalVals;
} }
} }
// set values (inPlace) to mpe // set values (inPlace) to maximum
for (Key j : frontals()) { for (Key j : frontals()) {
(*values)[j] = mpe[j]; (*values)[j] = mpe[j];
} }
} }
/* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining // Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way size_t max = 0;
size_t mpe = 0;
DiscreteValues frontals;
double maxP = 0; double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
Key j = (firstFrontalKey()); Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) { for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value; frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues) double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update solution if better
if (pValueS > maxP) {
maxP = pValueS;
max = value;
}
}
return max;
}
#endif
/* ************************************************************************** */
size_t DiscreteConditional::argmax() const {
size_t maxValue = 0;
double maxP = 0;
assert(nrFrontals() == 1);
assert(nrParents() == 0);
DiscreteValues frontals;
Key j = firstFrontalKey();
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = (*this)(frontals);
// Update MPE solution if better // Update MPE solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
mpe = value; maxValue = value;
} }
} }
return mpe; return maxValue;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator static mt19937 rng(2); // random number generator
// Get the correct conditional density // Get the correct conditional density
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way // TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) { if (nrFrontals() != 1) {
@ -323,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng); return distribution(rng);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const { size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1) if (nrParents() != 1)
throw std::invalid_argument( throw std::invalid_argument(
@ -334,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(values); return sample(values);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::sample() const { size_t DiscreteConditional::sample() const {
if (nrParents() != 0) if (nrParents() != 0)
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional(const DiscreteKey& key, const std::string& spec) DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {} : DiscreteConditional(Signature(key, {}, spec)) {}
/** /**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/ */
DiscreteConditional(const DecisionTreeFactor& joint, DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal); const DecisionTreeFactor& marginal);
/** /**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys. * Makes sure the keys are ordered as given. Does not check orderedKeys.
@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values); return ADT::operator()(values);
} }
/** /**
* @brief restrict to given *parent* values. * @brief restrict to given *parent* values.
* *
* Note: does not need be complete set. Examples: * Note: does not need be complete set. Examples:
* *
* P(C|D,E) + . -> P(C|D,E) * P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D) * P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E) * P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C) * P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error! * P(C|D,E) + C -> error!
* *
* @return a shared_ptr to a new DiscreteConditional * @return a shared_ptr to a new DiscreteConditional
*/ */
shared_ptr choose(const DiscreteValues& given) const; shared_ptr choose(const DiscreteValues& given) const;
@ -179,13 +179,6 @@ class GTSAM_EXPORT DiscreteConditional
/** Single variable version of likelihood. */ /** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
/**
* solve a conditional
* @param parentsValues Known values of the parents
* @return MPE value of the child (1 frontal variable).
*/
size_t solve(const DiscreteValues& parentsValues) const;
/** /**
* sample * sample
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents
@ -199,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional
/// Zero parent version. /// Zero parent version.
size_t sample() const; size_t sample() const;
/**
* @brief Return assignment that maximizes distribution.
* @return Optimal assignment (1 frontal variable).
*/
size_t argmax() const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
/// solve a conditional, in place
void solveInPlace(DiscreteValues* parentsValues) const;
/// sample in place, stores result in partial solution /// sample in place, stores result in partial solution
void sampleInPlace(DiscreteValues* parentsValues) const; void sampleInPlace(DiscreteValues* parentsValues) const;
@ -228,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
/// @}
#endif
protected:
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
}; };
// DiscreteConditional // DiscreteConditional

View File

@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
/// Return entire probability mass function. /// Return entire probability mass function.
std::vector<double> pmf() const; std::vector<double> pmf() const;
/**
* solve a conditional
* @return MPE value of the child (1 frontal variable).
*/
size_t solve() const { return Base::solve({}); }
/**
* sample
* @return sample from conditional
*/
size_t sample() const { return Base::sample(); }
/// @} /// @}
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
/// @}
#endif
}; };
// DiscreteDistribution // DiscreteDistribution

View File

@ -21,6 +21,7 @@
#include <gtsam/discrete/DiscreteEliminationTree.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h> #include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph-inst.h> #include <gtsam/inference/EliminateableFactorGraph-inst.h>
#include <gtsam/inference/FactorGraph-inst.h> #include <gtsam/inference/FactorGraph-inst.h>
@ -95,22 +96,85 @@ namespace gtsam {
// } // }
// } // }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize() const // Alternate eliminate function for MPE
{
gttic(DiscreteFactorGraph_optimize);
return BaseEliminateable::eliminateSequential()->optimize();
}
/* ************************************************************************* */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); gttic(product);
DecisionTreeFactor product; DecisionTreeFactor product;
for(const DiscreteFactor::shared_ptr& factor: factors) for (auto&& factor : factors) product = (*factor) * product;
product = (*factor) * product; gttoc(product);
// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
gttoc(max);
// Ordering keys for the conditional so that frontalKeys are really in front
DiscreteKeys orderedKeys;
for (auto&& key : frontalKeys)
orderedKeys.emplace_back(key, product.cardinality(key));
for (auto&& key : max->keys())
orderedKeys.emplace_back(key, product.cardinality(key));
// Make lookup with product
gttic(lookup);
size_t nrFrontals = frontalKeys.size();
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
orderedKeys, product);
gttoc(lookup);
return std::make_pair(
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
}
/* ************************************************************************ */
// The max-product solution below is a bit clunky: the elimination machinery
// does not allow for differently *typed* versions of elimination, so we
// eliminate into a Bayes Net using the special eliminate function above, and
// then create the DiscreteLookupDAG after the fact, in linear time.
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet =
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_maxProduct);
auto bayesNet =
BaseEliminateable::eliminateSequential(ordering, EliminateForMPE);
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
}
/* ************************************************************************ */
DiscreteValues DiscreteFactorGraph::optimize(
OptionalOrderingType orderingType) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(orderingType);
return dag.argmax();
}
DiscreteValues DiscreteFactorGraph::optimize(
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(ordering);
return dag.argmax();
}
/* ************************************************************************ */
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product); gttoc(product);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
@ -120,15 +184,18 @@ namespace gtsam {
// Ordering keys for the conditional so that frontalKeys are really in front // Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys; Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
sum->keys().end());
// now divide product/sum to get conditional // now divide product/sum to get conditional
gttic(divide); gttic(divide);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); auto conditional =
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
gttoc(divide); gttoc(divide);
return std::make_pair(cond, sum); return std::make_pair(conditional, sum);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -18,10 +18,11 @@
#pragma once #pragma once
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
@ -128,18 +129,39 @@ class GTSAM_EXPORT DiscreteFactorGraph
const std::string& s = "DiscreteFactorGraph", const std::string& s = "DiscreteFactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/** Solve the factor graph by performing variable elimination in COLAMD order using /**
* the dense elimination function specified in \c function, * @brief Implement the max-product algorithm
* followed by back-substitution resulting from elimination. Is equivalent *
* to calling graph.eliminateSequential()->optimize(). */ * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
DiscreteValues optimize() const; * @return DiscreteLookupDAG::shared_ptr DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Implement the max-product algorithm
*
* @param ordering
* @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables
*/
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
// /** Permute the variables in the factors */ /**
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); * @brief Find the maximum probable explanation (MPE) by doing max-product.
// *
// /** Apply a reduction, which is a remapping of variable indices. */ * @param orderingType
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); * @return DiscreteValues : MPE
*/
DiscreteValues optimize(
OptionalOrderingType orderingType = boost::none) const;
/**
* @brief Find the maximum probable explanation (MPE) by doing max-product.
*
* @param ordering
* @return DiscreteValues : MPE
*/
DiscreteValues optimize(const Ordering& ordering) const;
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -33,16 +33,13 @@ namespace gtsam {
KeyVector DiscreteKeys::indices() const { KeyVector DiscreteKeys::indices() const {
KeyVector js; KeyVector js;
for(const DiscreteKey& key: *this) for (const DiscreteKey& key : *this) js.push_back(key.first);
js.push_back(key.first);
return js; return js;
} }
map<Key,size_t> DiscreteKeys::cardinalities() const { map<Key, size_t> DiscreteKeys::cardinalities() const {
map<Key,size_t> cs; map<Key, size_t> cs;
cs.insert(begin(),end()); cs.insert(begin(), end());
// for(const DiscreteKey& key: *this)
// cs.insert(key);
return cs; return cs;
} }

View File

@ -0,0 +1,127 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.cpp
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <string>
#include <utility>
using std::pair;
using std::vector;
namespace gtsam {
/* ************************************************************************** */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
void DiscreteLookupTable::print(const std::string& s,
const KeyFormatter& formatter) const {
using std::cout;
using std::endl;
cout << s << " g( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
if (nrParents()) {
cout << "; ";
for (const_iterator it = beginParents(); it != endParents(); ++it) {
cout << formatter(*it) << " ";
}
}
cout << "):\n";
ADT::print("", formatter);
cout << endl;
}
/* ************************************************************************** */
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
double maxP = 0;
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ************************************************************************** */
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way
size_t mpe = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
}
}
return mpe;
}
/* ************************************************************************** */
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
const DiscreteBayesNet& bayesNet) {
DiscreteLookupDAG dag;
for (auto&& conditional : bayesNet) {
if (auto lookupTable =
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
dag.push_back(lookupTable);
} else {
throw std::runtime_error(
"DiscreteFactorGraph::maxProduct: Expected look up table.");
}
}
return dag;
}
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}
/* ************************************************************************** */
} // namespace gtsam

View File

@ -0,0 +1,140 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.h
* @date January, 2022
* @author Frank dellaert
*/
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
class DiscreteBayesNet;
/**
* @brief DiscreteLookupTable table for max-product
*
* Inherits from discrete conditional for convenience, but is not normalized.
* Is used in the max-product algorithm.
*/
class DiscreteLookupTable : public DiscreteConditional {
public:
using This = DiscreteLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a orted list of gtsam::Keys
* @param potentials the algebraic decision tree with lookup values
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}
/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/**
* @brief return assignment for single frontal variable that maximizes value.
* @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/
size_t argmax(const DiscreteValues& parentsValues) const;
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
};
/** A DAG made from lookup tables, as defined above. */
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
public:
using Base = BayesNet<DiscreteLookupTable>;
using This = DiscreteLookupDAG;
using shared_ptr = boost::shared_ptr<This>;
/// @name Standard Constructors
/// @{
/// Construct empty DAG.
DiscreteLookupDAG() {}
/// Create from BayesNet with LookupTables
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
/// Destructor
virtual ~DiscreteLookupDAG() {}
/// @}
/// @name Testable
/// @{
/** Check equality */
bool equals(const This& bn, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
}
/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
// traits
template <>
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
} // namespace gtsam

View File

@ -111,11 +111,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DecisionTreeFactor* likelihood( gtsam::DecisionTreeFactor* likelihood(
const gtsam::DiscreteValues& frontalValues) const; const gtsam::DiscreteValues& frontalValues) const;
gtsam::DecisionTreeFactor* likelihood(size_t value) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const;
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const; size_t sample(size_t value) const;
size_t sample() const; size_t sample() const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
@ -138,7 +136,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
double operator()(size_t value) const; double operator()(size_t value) const;
std::vector<double> pmf() const; std::vector<double> pmf() const;
size_t solve() const; size_t argmax() const;
}; };
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
@ -163,8 +161,6 @@ class DiscreteBayesNet {
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
@ -217,6 +213,21 @@ class DiscreteBayesTree {
std::map<gtsam::Key, std::vector<std::string>> names) const; std::map<gtsam::Key, std::vector<std::string>> names) const;
}; };
#include <gtsam/discrete/DiscreteLookupDAG.h>
class DiscreteLookupDAG {
DiscreteLookupDAG();
void push_back(const gtsam::DiscreteLookupTable* table);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::DiscreteLookupTable* at(size_t i) const;
void print(string s = "DiscreteLookupDAG\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
gtsam::DiscreteValues argmax() const;
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
};
#include <gtsam/inference/DotWriter.h> #include <gtsam/inference/DotWriter.h>
class DotWriter { class DotWriter {
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
@ -260,6 +271,9 @@ class DiscreteFactorGraph {
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteLookupDAG maxProduct();
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesNet eliminateSequential(); gtsam::DiscreteBayesNet eliminateSequential();
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph> std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>

View File

@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) {
DiscreteConditional expected2(Bronchitis % "11/9"); DiscreteConditional expected2(Bronchitis % "11/9");
EXPECT(assert_equal(expected2, *chordal->back())); EXPECT(assert_equal(expected2, *chordal->back()));
// solve
auto actualMPE = chordal->optimize();
DiscreteValues expectedMPE;
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 0);
EXPECT(assert_equal(expectedMPE, actualMPE));
// add evidence, we were in Asia and we have dyspnea // add evidence, we were in Asia and we have dyspnea
fg.add(Asia, "0 1"); fg.add(Asia, "0 1");
fg.add(Dyspnea, "0 1"); fg.add(Dyspnea, "0 1");
// solve again, now with evidence // solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto actualMPE2 = chordal2->optimize(); EXPECT(assert_equal(expected2, *chordal->back()));
DiscreteValues expectedMPE2;
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
LungCancer.first, 0)(Bronchitis.first, 1);
EXPECT(assert_equal(expectedMPE2, actualMPE2));
// now sample from it // now sample from it
DiscreteValues expectedSample; DiscreteValues expectedSample;

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/* /*
* @file testDiscretePrior.cpp * @file testDiscreteDistribution.cpp
* @brief unit tests for DiscreteDistribution * @brief unit tests for DiscreteDistribution
* @author Frank dellaert * @author Frank dellaert
* @date December 2021 * @date December 2021
@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) {
prior.sample(); prior.sample();
} }
/* ************************************************************************* */
TEST(DiscreteDistribution, argmax) {
DiscreteDistribution prior(X % "2/3");
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -30,8 +30,8 @@ using namespace std;
using namespace gtsam; using namespace gtsam;
/* ************************************************************************* */ /* ************************************************************************* */
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(AI, "1 0 0 1"); graph.add(AI, "1 0 0 1");
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
// graph.print("Graph: "); // Check MPE.
DecisionTreeFactor product = graph.product(); auto actualMPE = graph.optimize();
DecisionTreeFactor::shared_ptr sum = product.sum(1); DiscreteValues mpe;
// sum->print("Debug SUM: "); insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); EXPECT(assert_equal(mpe, actualMPE));
// cond->print("marginal:");
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
// result.first->print("BayesNet: ");
// result.second->print("New factor: ");
//
Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3);
DiscreteEliminationTree eliminationTree(graph, ordering);
// eliminationTree.print("Elimination tree: ");
eliminationTree.eliminate(EliminateDiscrete);
// solver.optimize();
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, test) TEST(DiscreteFactorGraph, test) {
{
// Declare keys and ordering // Declare keys and ordering
DiscreteKey C(0,2), B(1,2), A(2,2); DiscreteKey C(0, 2), B(1, 2), A(2, 2);
// A simple factor graph (A)-fAC-(C)-fBC-(B) // A simple factor graph (A)-fAC-(C)-fBC-(B)
// with smoothness priors // with smoothness priors
@ -127,77 +112,109 @@ TEST( DiscreteFactorGraph, test)
graph.add(C & B, "3 1 1 3"); graph.add(C & B, "3 1 1 3");
// Test EliminateDiscrete // Test EliminateDiscrete
// FIXME: apparently Eliminate returns a conditional rather than a net
Ordering frontalKeys; Ordering frontalKeys;
frontalKeys += Key(0); frontalKeys += Key(0);
DiscreteConditional::shared_ptr conditional; DiscreteConditional::shared_ptr conditional;
DecisionTreeFactor::shared_ptr newFactor; DecisionTreeFactor::shared_ptr newFactor;
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
// Check Bayes net // Check Conditional
CHECK(conditional); CHECK(conditional);
DiscreteBayesNet expected;
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
// cout << signature << endl;
DiscreteConditional expectedConditional(signature); DiscreteConditional expectedConditional(signature);
EXPECT(assert_equal(expectedConditional, *conditional)); EXPECT(assert_equal(expectedConditional, *conditional));
expected.add(signature);
// Check Factor // Check Factor
CHECK(newFactor); CHECK(newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
EXPECT(assert_equal(expectedFactor, *newFactor)); EXPECT(assert_equal(expectedFactor, *newFactor));
// add conditionals to complete expected Bayes net // Test using elimination tree
expected.add(B | A = "5/3 3/5");
expected.add(A % "1/1");
// GTSAM_PRINT(expected);
// Test elimination tree
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2); ordering += Key(0), Key(1), Key(2);
DiscreteEliminationTree etree(graph, ordering); DiscreteEliminationTree etree(graph, ordering);
DiscreteBayesNet::shared_ptr actual; DiscreteBayesNet::shared_ptr actual;
DiscreteFactorGraph::shared_ptr remainingGraph; DiscreteFactorGraph::shared_ptr remainingGraph;
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
EXPECT(assert_equal(expected, *actual));
// // Test solver // Check Bayes net
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); DiscreteBayesNet expectedBayesNet;
// EXPECT(assert_equal(expected, *actual2)); expectedBayesNet.add(signature);
expectedBayesNet.add(B | A = "5/3 3/5");
expectedBayesNet.add(A % "1/1");
EXPECT(assert_equal(expectedBayesNet, *actual));
// Test optimization // Test eliminateSequential
DiscreteValues expectedValues; DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
insert(expectedValues)(0, 0)(1, 0)(2, 0); EXPECT(assert_equal(expectedBayesNet, *actual2));
auto actualValues = graph.optimize();
EXPECT(assert_equal(expectedValues, actualValues)); // Test mpe
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 0)(2, 0);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE) TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
{
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey C(0,2), A(1,2), B(2,2); DiscreteKey C(0, 2), A(1, 2), B(2, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & A, "0.2 0.8 0.3 0.7");
graph.add(C & B, "0.1 0.9 0.4 0.6"); graph.add(C & B, "0.1 0.9 0.4 0.6");
// graph.product().print();
// DiscreteSequentialSolver(graph).eliminate()->print();
auto actualMPE = graph.optimize(); // Created expected MPE
DiscreteValues mpe;
insert(mpe)(0, 0)(1, 1)(2, 1);
DiscreteValues expectedMPE; // Do max-product with different orderings
insert(expectedMPE)(0, 0)(1, 1)(2, 1); for (Ordering::OrderingType orderingType :
EXPECT(assert_equal(expectedMPE, actualMPE)); {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
Ordering::CUSTOM}) {
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
auto actualMPE2 = graph.optimize(); // all in one
EXPECT(assert_equal(mpe, actualMPE2));
}
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) TEST(DiscreteFactorGraph, marginalIsNotMPE) {
{ // Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
// MPE does not have A=0.
DiscreteBayesNet bayesNet;
bayesNet.add(B | A = "1/1 1/2");
bayesNet.add(A % "10/9");
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// Which we verify using max-product:
DiscreteFactorGraph graph(bayesNet);
auto actualMPE = graph.optimize();
EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
auto notOptimal = bayesNet.optimize();
EXPECT(graph(notOptimal) < graph(mpe));
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
#endif
}
/* ************************************************************************* */
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
// The factor graph in Darwiche09book, page 244 // The factor graph in Darwiche09book, page 244
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
// Create Factor graph // Create Factor graph
DiscreteFactorGraph graph; DiscreteFactorGraph graph;
@ -206,53 +223,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(C & T1, "0.80 0.20 0.20 0.80");
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
//graph.product().print("Darwiche-product");
// graph.product().potentials().dot("Darwiche-product");
// DiscreteSequentialSolver(graph).eliminate()->print();
DiscreteValues expectedMPE; DiscreteValues mpe;
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
// You can check visually by printing product:
// graph.product().print("Darwiche-product");
// Use the solver machinery. // Check MPE.
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); auto actualMPE = graph.optimize();
auto actualMPE = chordal->optimize(); EXPECT(assert_equal(mpe, actualMPE));
EXPECT(assert_equal(expectedMPE, actualMPE));
// DiscreteConditional::shared_ptr root = chordal->back();
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
// Let us create the Bayes tree here, just for fun, because we don't use it now
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
//// bayesTree->print("Bayes Tree");
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
// Check Bayes Net
Ordering ordering; Ordering ordering;
ordering += Key(0),Key(1),Key(2),Key(3),Key(4); ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); auto chordal = graph.eliminateSequential(ordering);
// bayesTree->print("Bayes Tree"); EXPECT_LONGS_EQUAL(5, chordal->size());
EXPECT_LONGS_EQUAL(2,bayesTree->size()); #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
auto notOptimal = chordal->optimize(); // not MPE !
#ifdef OLD EXPECT(graph(notOptimal) < graph(mpe));
// Create the elimination tree manually
VariableIndexOrdered structure(graph);
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
ETree::shared_ptr eTree = ETree::Create(graph, structure);
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
// eliminate normally and check solution
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
auto actualMPE = optimize(*bayesNet);
EXPECT(assert_equal(expectedMPE, actualMPE));
// Approximate and check solution
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
// EXPECT(assert_equal(expectedMPE, *actualMPE));
#endif #endif
// Let us create the Bayes tree here, just for fun, because we don't use it
DiscreteBayesTree::shared_ptr bayesTree =
graph.eliminateMultifrontal(ordering);
// bayesTree->print("Bayes Tree");
EXPECT_LONGS_EQUAL(2, bayesTree->size());
} }
#ifdef OLD #ifdef OLD
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -0,0 +1,58 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testDiscreteLookupDAG.cpp
*
* @date January, 2022
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <boost/assign/list_inserter.hpp>
#include <boost/assign/std/map.hpp>
using namespace gtsam;
using namespace boost::assign;
/* ************************************************************************* */
TEST(DiscreteLookupDAG, argmax) {
using ADT = AlgebraicDecisionTree<Key>;
// Declare 2 keys
DiscreteKey A(0, 2), B(1, 2);
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
DiscreteLookupDAG dag;
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
dag.add(1, DiscreteKeys{B, A}, adtB);
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
dag.add(1, DiscreteKeys{A}, adtA);
// The expected MPE is A=1, B=1
DiscreteValues mpe;
insert(mpe)(0, 1)(1, 1);
// check:
auto actualMPE = dag.argmax();
EXPECT(assert_equal(mpe, actualMPE));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -25,15 +25,12 @@
namespace gtsam { namespace gtsam {
/** /**
* TODO: Update comments. The following comments are out of date!!! * Base class for conditional densities. This class iterators and
*
* Base class for conditional densities, templated on KEY type. This class
* provides storage for the keys involved in a conditional, and iterators and
* access to the frontal and separator keys. * access to the frontal and separator keys.
* *
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
* to the associated factor type and shared_ptr type of the derived class. See * to the associated factor type and shared_ptr type of the derived class. See
* IndexConditional and GaussianConditional for examples. * SymbolicConditional and GaussianConditional for examples.
* \nosubgrouping * \nosubgrouping
*/ */
template<class FACTOR, class DERIVEDCONDITIONAL> template<class FACTOR, class DERIVEDCONDITIONAL>

View File

@ -14,18 +14,6 @@ using namespace std;
namespace gtsam { namespace gtsam {
/// Find the best total assignment - can be expensive
DiscreteValues CSP::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
return chordal->optimize();
}
/// Find the best total assignment - can be expensive
DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const {
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
return chordal->optimize();
}
bool CSP::runArcConsistency(const VariableIndex& index, bool CSP::runArcConsistency(const VariableIndex& index,
Domains* domains) const { Domains* domains) const {
bool changed = false; bool changed = false;

View File

@ -43,12 +43,6 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
// return result; // return result;
// } // }
/// Find the best total assignment - can be expensive.
DiscreteValues optimalAssignment() const;
/// Find the best total assignment, with given ordering - can be expensive.
DiscreteValues optimalAssignment(const Ordering& ordering) const;
// /* // /*
// * Perform loopy belief propagation // * Perform loopy belief propagation
// * True belief propagation would check for each value in domain // * True belief propagation would check for each value in domain

View File

@ -255,23 +255,6 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
return chordal; return chordal;
} }
/** Find the best total assignment - can be expensive */
DiscreteValues Scheduler::optimalAssignment() const {
DiscreteBayesNet::shared_ptr chordal = eliminate();
if (ISDEBUG("Scheduler::optimalAssignment")) {
DiscreteBayesNet::const_iterator it = chordal->end() - 1;
const Student& student = students_.front();
cout << endl;
(*it)->print(student.name_);
}
gttic(my_optimize);
DiscreteValues mpe = chordal->optimize();
gttoc(my_optimize);
return mpe;
}
/** find the assignment of students to slots with most possible committees */ /** find the assignment of students to slots with most possible committees */
DiscreteValues Scheduler::bestSchedule() const { DiscreteValues Scheduler::bestSchedule() const {
DiscreteValues best; DiscreteValues best;

View File

@ -147,9 +147,6 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
/** Eliminate, return a Bayes net */ /** Eliminate, return a Bayes net */
DiscreteBayesNet::shared_ptr eliminate() const; DiscreteBayesNet::shared_ptr eliminate() const;
/** Find the best total assignment - can be expensive */
DiscreteValues optimalAssignment() const;
/** find the assignment of students to slots with most possible committees */ /** find the assignment of students to slots with most possible committees */
DiscreteValues bestSchedule() const; DiscreteValues bestSchedule() const;

View File

@ -122,7 +122,7 @@ void runLargeExample() {
// SETDEBUG("timing-verbose", true); // SETDEBUG("timing-verbose", true);
SETDEBUG("DiscreteConditional::DiscreteConditional", true); SETDEBUG("DiscreteConditional::DiscreteConditional", true);
gttic(large); gttic(large);
auto MPE = scheduler.optimalAssignment(); auto MPE = scheduler.optimize();
gttoc(large); gttoc(large);
tictoc_finishedIteration(); tictoc_finishedIteration();
tictoc_print(); tictoc_print();
@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(6 - s); DiscreteKey dkey = scheduler.studentKey(6 - s);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
size_t count = (*root)(values); size_t count = (*root)(values);
@ -319,11 +319,11 @@ void accomodateStudent() {
// GTSAM_PRINT(*chordal); // GTSAM_PRINT(*chordal);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(0); DiscreteKey dkey = scheduler.studentKey(0);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
size_t count = (*root)(values); size_t count = (*root)(values);
cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0) cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0)

View File

@ -143,7 +143,7 @@ void runLargeExample() {
} }
#else #else
gttic(large); gttic(large);
auto MPE = scheduler.optimalAssignment(); auto MPE = scheduler.optimize();
gttoc(large); gttoc(large);
tictoc_finishedIteration(); tictoc_finishedIteration();
tictoc_print(); tictoc_print();
@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
size_t count = (*root)(values); size_t count = (*root)(values);

View File

@ -167,7 +167,7 @@ void runLargeExample() {
} }
#else #else
gttic(large); gttic(large);
auto MPE = scheduler.optimalAssignment(); auto MPE = scheduler.optimize();
gttoc(large); gttoc(large);
tictoc_finishedIteration(); tictoc_finishedIteration();
tictoc_print(); tictoc_print();
@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) {
root->print(""/*scheduler.studentName(s)*/); root->print(""/*scheduler.studentName(s)*/);
// solve root node only // solve root node only
DiscreteValues values; size_t bestSlot = root->argmax();
size_t bestSlot = root->solve(values);
// get corresponding count // get corresponding count
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
DiscreteValues values;
values[dkey.first] = bestSlot; values[dkey.first] = bestSlot;
double count = (*root)(values); double count = (*root)(values);

View File

@ -132,7 +132,7 @@ TEST(CSP, allInOne) {
EXPECT(assert_equal(expectedProduct, product)); EXPECT(assert_equal(expectedProduct, product));
// Solve // Solve
auto mpe = csp.optimalAssignment(); auto mpe = csp.optimize();
DiscreteValues expected; DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(expected, mpe));
@ -172,22 +172,18 @@ TEST(CSP, WesternUS) {
csp.addAllDiff(WY, CO); csp.addAllDiff(WY, CO);
csp.addAllDiff(CO, NM); csp.addAllDiff(CO, NM);
DiscreteValues mpe;
insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0);
// Create ordering according to example in ND-CSP.lyx // Create ordering according to example in ND-CSP.lyx
Ordering ordering; Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
Key(8), Key(9), Key(10); Key(8), Key(9), Key(10);
// Solve using that ordering:
auto mpe = csp.optimalAssignment(ordering);
// GTSAM_PRINT(mpe);
DiscreteValues expected;
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
UT.first, 1)(AZ.first, 0);
// TODO: Fix me! mpe result seems to be right. (See the printing) // Solve using that ordering:
// It has the same prob as the expected solution. auto actualMPE = csp.optimize(ordering);
// Is mpe another solution, or the expected solution is unique???
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(mpe, actualMPE));
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
// Write out the dual graph for hmetis // Write out the dual graph for hmetis
@ -227,7 +223,7 @@ TEST(CSP, ArcConsistency) {
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
// Solve // Solve
auto mpe = csp.optimalAssignment(); auto mpe = csp.optimize();
DiscreteValues expected; DiscreteValues expected;
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
EXPECT(assert_equal(expected, mpe)); EXPECT(assert_equal(expected, mpe));

View File

@ -122,7 +122,7 @@ TEST(schedulingExample, test) {
// Do exact inference // Do exact inference
gttic(small); gttic(small);
auto MPE = s.optimalAssignment(); auto MPE = s.optimize();
gttoc(small); gttoc(small);
// print MPE, commented out as unit tests don't print // print MPE, commented out as unit tests don't print

View File

@ -100,7 +100,7 @@ class Sudoku : public CSP {
/// solve and print solution /// solve and print solution
void printSolution() const { void printSolution() const {
auto MPE = optimalAssignment(); auto MPE = optimize();
printAssignment(MPE); printAssignment(MPE);
} }
@ -126,7 +126,7 @@ TEST(Sudoku, small) {
0, 1, 0, 0); 0, 1, 0, 0);
// optimize and check // optimize and check
auto solution = csp.optimalAssignment(); auto solution = csp.optimize();
DiscreteValues expected; DiscreteValues expected;
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)( insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)( csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
@ -148,7 +148,7 @@ TEST(Sudoku, small) {
EXPECT_LONGS_EQUAL(16, new_csp.size()); EXPECT_LONGS_EQUAL(16, new_csp.size());
// Check that solution // Check that solution
auto new_solution = new_csp.optimalAssignment(); auto new_solution = new_csp.optimize();
// csp.printAssignment(new_solution); // csp.printAssignment(new_solution);
EXPECT(assert_equal(expected, new_solution)); EXPECT(assert_equal(expected, new_solution));
} }
@ -250,7 +250,7 @@ TEST(Sudoku, AJC_3star_Feb8_2012) {
EXPECT_LONGS_EQUAL(81, new_csp.size()); EXPECT_LONGS_EQUAL(81, new_csp.size());
// Check that solution // Check that solution
auto solution = new_csp.optimalAssignment(); auto solution = new_csp.optimize();
// csp.printAssignment(solution); // csp.printAssignment(solution);
EXPECT_LONGS_EQUAL(6, solution.at(key99)); EXPECT_LONGS_EQUAL(6, solution.at(key99));
} }

View File

@ -79,7 +79,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
self.gtsamAssertEquals(chordal.at(7), expected2) self.gtsamAssertEquals(chordal.at(7), expected2)
# solve # solve
actualMPE = chordal.optimize() actualMPE = fg.optimize()
expectedMPE = DiscreteValues() expectedMPE = DiscreteValues()
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
expectedMPE[key[0]] = 0 expectedMPE[key[0]] = 0
@ -94,8 +94,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
fg.add(Dyspnea, "0 1") fg.add(Dyspnea, "0 1")
# solve again, now with evidence # solve again, now with evidence
chordal2 = fg.eliminateSequential(ordering) actualMPE2 = fg.optimize()
actualMPE2 = chordal2.optimize()
expectedMPE2 = DiscreteValues() expectedMPE2 = DiscreteValues()
for key in [XRay, Tuberculosis, Either, LungCancer]: for key in [XRay, Tuberculosis, Either, LungCancer]:
expectedMPE2[key[0]] = 0 expectedMPE2[key[0]] = 0
@ -105,6 +104,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
list(expectedMPE2.items())) list(expectedMPE2.items()))
# now sample from it # now sample from it
chordal2 = fg.eliminateSequential(ordering)
actualSample = chordal2.sample() actualSample = chordal2.sample()
self.assertEqual(len(actualSample), 8) self.assertEqual(len(actualSample), 8)
@ -122,10 +122,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
for key in [Asia, Smoking]: for key in [Asia, Smoking]:
given[key[0]] = 0 given[key[0]] = 0
# Now optimize fragment:
actual = fragment.optimize(given)
self.assertEqual(len(actual), 5)
# Now sample from fragment: # Now sample from fragment:
actual = fragment.sample(given) actual = fragment.sample(given)
self.assertEqual(len(actual), 5) self.assertEqual(len(actual), 5)

View File

@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
X = 0, 2 X = 0, 2
class TestDiscretePrior(GtsamTestCase): class TestDiscreteDistribution(GtsamTestCase):
"""Tests for Discrete Priors.""" """Tests for Discrete Priors."""
def test_constructor(self): def test_constructor(self):