Merge pull request #1050 from borglab/feature/betterMPE
commit
b441eea976
|
@ -53,10 +53,9 @@ int main(int argc, char **argv) {
|
||||||
// Create solver and eliminate
|
// Create solver and eliminate
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
||||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
|
||||||
|
|
||||||
// solve
|
// solve
|
||||||
auto mpe = chordal->optimize();
|
auto mpe = fg.optimize();
|
||||||
GTSAM_PRINT(mpe);
|
GTSAM_PRINT(mpe);
|
||||||
|
|
||||||
// We can also build a Bayes tree (directed junction tree).
|
// We can also build a Bayes tree (directed junction tree).
|
||||||
|
@ -69,14 +68,14 @@ int main(int argc, char **argv) {
|
||||||
fg.add(Dyspnea, "0 1");
|
fg.add(Dyspnea, "0 1");
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
auto mpe2 = fg.optimize();
|
||||||
auto mpe2 = chordal2->optimize();
|
|
||||||
GTSAM_PRINT(mpe2);
|
GTSAM_PRINT(mpe2);
|
||||||
|
|
||||||
// We can also sample from it
|
// We can also sample from it
|
||||||
|
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||||
cout << "\n10 samples:" << endl;
|
cout << "\n10 samples:" << endl;
|
||||||
for (size_t i = 0; i < 10; i++) {
|
for (size_t i = 0; i < 10; i++) {
|
||||||
auto sample = chordal2->sample();
|
auto sample = chordal->sample();
|
||||||
GTSAM_PRINT(sample);
|
GTSAM_PRINT(sample);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
|
|
|
@ -85,7 +85,7 @@ int main(int argc, char **argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// "Most Probable Explanation", i.e., configuration with largest value
|
// "Most Probable Explanation", i.e., configuration with largest value
|
||||||
auto mpe = graph.eliminateSequential()->optimize();
|
auto mpe = graph.optimize();
|
||||||
cout << "\nMost Probable Explanation (MPE):" << endl;
|
cout << "\nMost Probable Explanation (MPE):" << endl;
|
||||||
print(mpe);
|
print(mpe);
|
||||||
|
|
||||||
|
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
|
||||||
graph.add(Cloudy, "1 0");
|
graph.add(Cloudy, "1 0");
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto mpe_with_evidence = graph.optimize();
|
||||||
auto mpe_with_evidence = chordal->optimize();
|
|
||||||
|
|
||||||
cout << "\nMPE given C=0:" << endl;
|
cout << "\nMPE given C=0:" << endl;
|
||||||
print(mpe_with_evidence);
|
print(mpe_with_evidence);
|
||||||
|
@ -110,7 +109,8 @@ int main(int argc, char **argv) {
|
||||||
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
|
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
|
||||||
<< endl;
|
<< endl;
|
||||||
|
|
||||||
// We can also sample from it
|
// We can also sample from the eliminated graph
|
||||||
|
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||||
cout << "\n10 samples:" << endl;
|
cout << "\n10 samples:" << endl;
|
||||||
for (size_t i = 0; i < 10; i++) {
|
for (size_t i = 0; i < 10; i++) {
|
||||||
auto sample = chordal->sample();
|
auto sample = chordal->sample();
|
||||||
|
|
|
@ -59,16 +59,16 @@ int main(int argc, char **argv) {
|
||||||
// Convert to factor graph
|
// Convert to factor graph
|
||||||
DiscreteFactorGraph factorGraph(hmm);
|
DiscreteFactorGraph factorGraph(hmm);
|
||||||
|
|
||||||
|
// Do max-prodcut
|
||||||
|
auto mpe = factorGraph.optimize();
|
||||||
|
GTSAM_PRINT(mpe);
|
||||||
|
|
||||||
// Create solver and eliminate
|
// Create solver and eliminate
|
||||||
// This will create a DAG ordered with arrow of time reversed
|
// This will create a DAG ordered with arrow of time reversed
|
||||||
DiscreteBayesNet::shared_ptr chordal =
|
DiscreteBayesNet::shared_ptr chordal =
|
||||||
factorGraph.eliminateSequential(ordering);
|
factorGraph.eliminateSequential(ordering);
|
||||||
chordal->print("Eliminated");
|
chordal->print("Eliminated");
|
||||||
|
|
||||||
// solve
|
|
||||||
auto mpe = chordal->optimize();
|
|
||||||
GTSAM_PRINT(mpe);
|
|
||||||
|
|
||||||
// We can also sample from it
|
// We can also sample from it
|
||||||
cout << "\n10 samples:" << endl;
|
cout << "\n10 samples:" << endl;
|
||||||
for (size_t k = 0; k < 10; k++) {
|
for (size_t k = 0; k < 10; k++) {
|
||||||
|
|
|
@ -68,9 +68,8 @@ int main(int argc, char** argv) {
|
||||||
<< graph.size() << " factors (Unary+Edge).";
|
<< graph.size() << " factors (Unary+Edge).";
|
||||||
|
|
||||||
// "Decoding", i.e., configuration with largest value
|
// "Decoding", i.e., configuration with largest value
|
||||||
// We use sequential variable elimination
|
// Uses max-product.
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto optimalDecoding = graph.optimize();
|
||||||
auto optimalDecoding = chordal->optimize();
|
|
||||||
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
|
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
|
||||||
|
|
||||||
// "Inference" Computing marginals for each node
|
// "Inference" Computing marginals for each node
|
||||||
|
|
|
@ -61,9 +61,8 @@ int main(int argc, char** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// "Decoding", i.e., configuration with largest value (MPE)
|
// "Decoding", i.e., configuration with largest value (MPE)
|
||||||
// We use sequential variable elimination
|
// Uses max-product
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto optimalDecoding = graph.optimize();
|
||||||
auto optimalDecoding = chordal->optimize();
|
|
||||||
GTSAM_PRINT(optimalDecoding);
|
GTSAM_PRINT(optimalDecoding);
|
||||||
|
|
||||||
// "Inference" Computing marginals
|
// "Inference" Computing marginals
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
#include <boost/format.hpp>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
@ -65,9 +66,13 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void DecisionTreeFactor::print(const string& s,
|
void DecisionTreeFactor::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
cout << s;
|
cout << s;
|
||||||
ADT::print("Potentials:",formatter);
|
cout << " f[";
|
||||||
|
for (auto&& key : keys())
|
||||||
|
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
||||||
|
cout << " ]" << endl;
|
||||||
|
ADT::print("Potentials:", formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -127,11 +127,16 @@ namespace gtsam {
|
||||||
return combine(keys, ADT::Ring::add);
|
return combine(keys, ADT::Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator values
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
shared_ptr max(size_t nrFrontals) const {
|
shared_ptr max(size_t nrFrontals) const {
|
||||||
return combine(nrFrontals, ADT::Ring::max);
|
return combine(nrFrontals, ADT::Ring::max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
|
shared_ptr max(const Ordering& keys) const {
|
||||||
|
return combine(keys, ADT::Ring::max);
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -43,6 +43,7 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
DiscreteValues DiscreteBayesNet::optimize() const {
|
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||||
DiscreteValues result;
|
DiscreteValues result;
|
||||||
return optimize(result);
|
return optimize(result);
|
||||||
|
@ -50,10 +51,16 @@ DiscreteValues DiscreteBayesNet::optimize() const {
|
||||||
|
|
||||||
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
|
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
|
||||||
// solve each node in turn in topological sort order (parents first)
|
// solve each node in turn in topological sort order (parents first)
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
|
||||||
|
#else
|
||||||
|
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
|
||||||
|
#endif
|
||||||
for (auto conditional : boost::adaptors::reverse(*this))
|
for (auto conditional : boost::adaptors::reverse(*this))
|
||||||
conditional->solveInPlace(&result);
|
conditional->solveInPlace(&result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteValues DiscreteBayesNet::sample() const {
|
DiscreteValues DiscreteBayesNet::sample() const {
|
||||||
|
|
|
@ -31,12 +31,12 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/** A Bayes net made from linear-Discrete densities */
|
/** A Bayes net made from discrete conditional distributions. */
|
||||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
|
|
||||||
typedef FactorGraph<DiscreteConditional> Base;
|
typedef BayesNet<DiscreteConditional> Base;
|
||||||
typedef DiscreteBayesNet This;
|
typedef DiscreteBayesNet This;
|
||||||
typedef DiscreteConditional ConditionalType;
|
typedef DiscreteConditional ConditionalType;
|
||||||
typedef boost::shared_ptr<This> shared_ptr;
|
typedef boost::shared_ptr<This> shared_ptr;
|
||||||
|
@ -45,7 +45,7 @@ namespace gtsam {
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Construct empty factor graph */
|
/// Construct empty Bayes net.
|
||||||
DiscreteBayesNet() {}
|
DiscreteBayesNet() {}
|
||||||
|
|
||||||
/** Construct from iterator over conditionals */
|
/** Construct from iterator over conditionals */
|
||||||
|
@ -98,27 +98,6 @@ namespace gtsam {
|
||||||
return evaluate(values);
|
return evaluate(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief solve by back-substitution.
|
|
||||||
*
|
|
||||||
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
|
||||||
* conditional will be optimized first. If the Bayes net resulted from
|
|
||||||
* eliminating a factor graph, this is true for the elimination ordering.
|
|
||||||
*
|
|
||||||
* @return a sampled value for all variables.
|
|
||||||
*/
|
|
||||||
DiscreteValues optimize() const;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief solve by back-substitution, given certain variables.
|
|
||||||
*
|
|
||||||
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
|
||||||
* Bayes net does not contain any conditionals for the given values.
|
|
||||||
*
|
|
||||||
* @return given values extended with optimized value for other variables.
|
|
||||||
*/
|
|
||||||
DiscreteValues optimize(DiscreteValues given) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief do ancestral sampling
|
* @brief do ancestral sampling
|
||||||
*
|
*
|
||||||
|
@ -152,7 +131,16 @@ namespace gtsam {
|
||||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const DiscreteFactor::Names& names = {}) const;
|
const DiscreteFactor::Names& names = {}) const;
|
||||||
|
|
||||||
|
///@}
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated functionality
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
DiscreteValues GTSAM_DEPRECATED optimize() const;
|
||||||
|
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
|
||||||
/// @}
|
/// @}
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
|
|
|
@ -16,26 +16,25 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/debug.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/base/debug.h>
|
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <set>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <set>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
using std::pair;
|
||||||
using std::stringstream;
|
using std::stringstream;
|
||||||
using std::vector;
|
using std::vector;
|
||||||
using std::pair;
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
|
@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
double tol) const {
|
double tol) const {
|
||||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||||
|
@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
DiscreteConditional::ADT DiscreteConditional::choose(
|
||||||
const DiscreteValues& given,
|
const DiscreteValues& given, bool forceComplete) const {
|
||||||
bool forceComplete = true) {
|
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
// branches based on the value of the parent variables.
|
// branches based on the value of the parent variables.
|
||||||
DiscreteConditional::ADT adt(conditional);
|
DiscreteConditional::ADT adt(*this);
|
||||||
size_t value;
|
size_t value;
|
||||||
for (Key j : conditional.parents()) {
|
for (Key j : parents()) {
|
||||||
try {
|
try {
|
||||||
value = given.at(j);
|
value = given.at(j);
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
|
@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||||
if (forceComplete) {
|
if (forceComplete) {
|
||||||
given.print("parentsValues: ");
|
given.print("parentsValues: ");
|
||||||
throw runtime_error(
|
throw runtime_error(
|
||||||
"DiscreteConditional::Choose: parent value missing");
|
"DiscreteConditional::choose: parent value missing");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||||
const DiscreteValues& given) const {
|
const DiscreteValues& given) const {
|
||||||
ADT adt = Choose(*this, given, false); // P(F|S=given)
|
ADT adt = choose(given, false); // P(F|S=given)
|
||||||
|
|
||||||
// Collect all keys not in given.
|
// Collect all keys not in given.
|
||||||
DiscreteKeys dKeys;
|
DiscreteKeys dKeys;
|
||||||
|
@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ****************************************************************************/
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
size_t parent_value) const {
|
size_t parent_value) const {
|
||||||
if (nrFrontals() != 1)
|
if (nrFrontals() != 1)
|
||||||
|
@ -238,8 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
DiscreteValues mpe;
|
DiscreteValues mpe;
|
||||||
|
@ -248,59 +247,79 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
// Get all Possible Configurations
|
// Get all Possible Configurations
|
||||||
const auto allPosbValues = frontalAssignments();
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
// Find the MPE
|
// Find the maximum
|
||||||
for (const auto& frontalVals : allPosbValues) {
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
// Update MPE solution if better
|
// Update maximum solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
mpe = frontalVals;
|
mpe = frontalVals;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// set values (inPlace) to mpe
|
// set values (inPlace) to maximum
|
||||||
for (Key j : frontals()) {
|
for (Key j : frontals()) {
|
||||||
(*values)[j] = mpe[j];
|
(*values)[j] = mpe[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
|
||||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
|
||||||
assert(nrFrontals() == 1);
|
|
||||||
Key j = (firstFrontalKey());
|
|
||||||
size_t sampled = sample(*values); // Sample variable given parents
|
|
||||||
(*values)[j] = sampled; // store result in partial solution
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Then, find the max over all remaining
|
// Then, find the max over all remaining
|
||||||
// TODO, only works for one key now, seems horribly slow this way
|
size_t max = 0;
|
||||||
size_t mpe = 0;
|
|
||||||
DiscreteValues frontals;
|
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
Key j = (firstFrontalKey());
|
Key j = (firstFrontalKey());
|
||||||
for (size_t value = 0; value < cardinality(j); value++) {
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
frontals[j] = value;
|
frontals[j] = value;
|
||||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
// Update solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
max = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteConditional::argmax() const {
|
||||||
|
size_t maxValue = 0;
|
||||||
|
double maxP = 0;
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
assert(nrParents() == 0);
|
||||||
|
DiscreteValues frontals;
|
||||||
|
Key j = firstFrontalKey();
|
||||||
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
frontals[j] = value;
|
||||||
|
double pValueS = (*this)(frontals);
|
||||||
// Update MPE solution if better
|
// Update MPE solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
mpe = value;
|
maxValue = value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return mpe;
|
return maxValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
Key j = (firstFrontalKey());
|
||||||
|
size_t sampled = sample(*values); // Sample variable given parents
|
||||||
|
(*values)[j] = sampled; // store result in partial solution
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
static mt19937 rng(2); // random number generator
|
static mt19937 rng(2); // random number generator
|
||||||
|
|
||||||
// Get the correct conditional density
|
// Get the correct conditional density
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
if (nrFrontals() != 1) {
|
if (nrFrontals() != 1) {
|
||||||
|
@ -323,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
return distribution(rng);
|
return distribution(rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
if (nrParents() != 1)
|
if (nrParents() != 1)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -334,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
return sample(values);
|
return sample(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::sample() const {
|
size_t DiscreteConditional::sample() const {
|
||||||
if (nrParents() != 0)
|
if (nrParents() != 0)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
|
|
@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
*/
|
*/
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal);
|
const DecisionTreeFactor& marginal);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||||
|
@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief restrict to given *parent* values.
|
* @brief restrict to given *parent* values.
|
||||||
*
|
*
|
||||||
* Note: does not need be complete set. Examples:
|
* Note: does not need be complete set. Examples:
|
||||||
*
|
*
|
||||||
* P(C|D,E) + . -> P(C|D,E)
|
* P(C|D,E) + . -> P(C|D,E)
|
||||||
* P(C|D,E) + E -> P(C|D)
|
* P(C|D,E) + E -> P(C|D)
|
||||||
* P(C|D,E) + D -> P(C|E)
|
* P(C|D,E) + D -> P(C|E)
|
||||||
* P(C|D,E) + D,E -> P(C)
|
* P(C|D,E) + D,E -> P(C)
|
||||||
* P(C|D,E) + C -> error!
|
* P(C|D,E) + C -> error!
|
||||||
*
|
*
|
||||||
* @return a shared_ptr to a new DiscreteConditional
|
* @return a shared_ptr to a new DiscreteConditional
|
||||||
*/
|
*/
|
||||||
shared_ptr choose(const DiscreteValues& given) const;
|
shared_ptr choose(const DiscreteValues& given) const;
|
||||||
|
@ -179,13 +179,6 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/** Single variable version of likelihood. */
|
/** Single variable version of likelihood. */
|
||||||
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* solve a conditional
|
|
||||||
* @param parentsValues Known values of the parents
|
|
||||||
* @return MPE value of the child (1 frontal variable).
|
|
||||||
*/
|
|
||||||
size_t solve(const DiscreteValues& parentsValues) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* sample
|
* sample
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
|
@ -199,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// Zero parent version.
|
/// Zero parent version.
|
||||||
size_t sample() const;
|
size_t sample() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return assignment that maximizes distribution.
|
||||||
|
* @return Optimal assignment (1 frontal variable).
|
||||||
|
*/
|
||||||
|
size_t argmax() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// solve a conditional, in place
|
|
||||||
void solveInPlace(DiscreteValues* parentsValues) const;
|
|
||||||
|
|
||||||
/// sample in place, stores result in partial solution
|
/// sample in place, stores result in partial solution
|
||||||
void sampleInPlace(DiscreteValues* parentsValues) const;
|
void sampleInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
|
||||||
|
@ -228,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated functionality
|
||||||
|
/// @{
|
||||||
|
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
|
||||||
|
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
/// @}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// Internal version of choose
|
||||||
|
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||||
|
bool forceComplete) const;
|
||||||
};
|
};
|
||||||
// DiscreteConditional
|
// DiscreteConditional
|
||||||
|
|
||||||
|
|
|
@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
|
||||||
/// Return entire probability mass function.
|
/// Return entire probability mass function.
|
||||||
std::vector<double> pmf() const;
|
std::vector<double> pmf() const;
|
||||||
|
|
||||||
/**
|
|
||||||
* solve a conditional
|
|
||||||
* @return MPE value of the child (1 frontal variable).
|
|
||||||
*/
|
|
||||||
size_t solve() const { return Base::solve({}); }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* sample
|
|
||||||
* @return sample from conditional
|
|
||||||
*/
|
|
||||||
size_t sample() const { return Base::sample(); }
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
/// @name Deprecated functionality
|
||||||
|
/// @{
|
||||||
|
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
|
||||||
|
/// @}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
// DiscreteDistribution
|
// DiscreteDistribution
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||||
#include <gtsam/inference/FactorGraph-inst.h>
|
#include <gtsam/inference/FactorGraph-inst.h>
|
||||||
|
|
||||||
|
@ -95,22 +96,85 @@ namespace gtsam {
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DiscreteValues DiscreteFactorGraph::optimize() const
|
// Alternate eliminate function for MPE
|
||||||
{
|
|
||||||
gttic(DiscreteFactorGraph_optimize);
|
|
||||||
return BaseEliminateable::eliminateSequential()->optimize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys) {
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
gttic(product);
|
gttic(product);
|
||||||
DecisionTreeFactor product;
|
DecisionTreeFactor product;
|
||||||
for(const DiscreteFactor::shared_ptr& factor: factors)
|
for (auto&& factor : factors) product = (*factor) * product;
|
||||||
product = (*factor) * product;
|
gttoc(product);
|
||||||
|
|
||||||
|
// max out frontals, this is the factor on the separator
|
||||||
|
gttic(max);
|
||||||
|
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||||
|
gttoc(max);
|
||||||
|
|
||||||
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
|
DiscreteKeys orderedKeys;
|
||||||
|
for (auto&& key : frontalKeys)
|
||||||
|
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||||
|
for (auto&& key : max->keys())
|
||||||
|
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||||
|
|
||||||
|
// Make lookup with product
|
||||||
|
gttic(lookup);
|
||||||
|
size_t nrFrontals = frontalKeys.size();
|
||||||
|
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
|
||||||
|
orderedKeys, product);
|
||||||
|
gttoc(lookup);
|
||||||
|
|
||||||
|
return std::make_pair(
|
||||||
|
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
// The max-product solution below is a bit clunky: the elimination machinery
|
||||||
|
// does not allow for differently *typed* versions of elimination, so we
|
||||||
|
// eliminate into a Bayes Net using the special eliminate function above, and
|
||||||
|
// then create the DiscreteLookupDAG after the fact, in linear time.
|
||||||
|
|
||||||
|
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
gttic(DiscreteFactorGraph_maxProduct);
|
||||||
|
auto bayesNet =
|
||||||
|
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
|
||||||
|
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||||
|
const Ordering& ordering) const {
|
||||||
|
gttic(DiscreteFactorGraph_maxProduct);
|
||||||
|
auto bayesNet =
|
||||||
|
BaseEliminateable::eliminateSequential(ordering, EliminateForMPE);
|
||||||
|
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteValues DiscreteFactorGraph::optimize(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
gttic(DiscreteFactorGraph_optimize);
|
||||||
|
DiscreteLookupDAG dag = maxProduct(orderingType);
|
||||||
|
return dag.argmax();
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteFactorGraph::optimize(
|
||||||
|
const Ordering& ordering) const {
|
||||||
|
gttic(DiscreteFactorGraph_optimize);
|
||||||
|
DiscreteLookupDAG dag = maxProduct(ordering);
|
||||||
|
return dag.argmax();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||||
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
|
const Ordering& frontalKeys) {
|
||||||
|
// PRODUCT: multiply all factors
|
||||||
|
gttic(product);
|
||||||
|
DecisionTreeFactor product;
|
||||||
|
for (auto&& factor : factors) product = (*factor) * product;
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
|
@ -120,15 +184,18 @@ namespace gtsam {
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
Ordering orderedKeys;
|
Ordering orderedKeys;
|
||||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
|
||||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
frontalKeys.end());
|
||||||
|
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
|
||||||
|
sum->keys().end());
|
||||||
|
|
||||||
// now divide product/sum to get conditional
|
// now divide product/sum to get conditional
|
||||||
gttic(divide);
|
gttic(divide);
|
||||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
|
auto conditional =
|
||||||
|
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||||
gttoc(divide);
|
gttoc(divide);
|
||||||
|
|
||||||
return std::make_pair(cond, sum);
|
return std::make_pair(conditional, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
@ -18,10 +18,11 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
|
||||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
#include <gtsam/inference/Ordering.h>
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
|
@ -128,18 +129,39 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
||||||
const std::string& s = "DiscreteFactorGraph",
|
const std::string& s = "DiscreteFactorGraph",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/** Solve the factor graph by performing variable elimination in COLAMD order using
|
/**
|
||||||
* the dense elimination function specified in \c function,
|
* @brief Implement the max-product algorithm
|
||||||
* followed by back-substitution resulting from elimination. Is equivalent
|
*
|
||||||
* to calling graph.eliminateSequential()->optimize(). */
|
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||||
DiscreteValues optimize() const;
|
* @return DiscreteLookupDAG::shared_ptr DAG with lookup tables
|
||||||
|
*/
|
||||||
|
DiscreteLookupDAG maxProduct(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Implement the max-product algorithm
|
||||||
|
*
|
||||||
|
* @param ordering
|
||||||
|
* @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables
|
||||||
|
*/
|
||||||
|
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
|
||||||
|
|
||||||
// /** Permute the variables in the factors */
|
/**
|
||||||
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
|
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||||
//
|
*
|
||||||
// /** Apply a reduction, which is a remapping of variable indices. */
|
* @param orderingType
|
||||||
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
|
* @return DiscreteValues : MPE
|
||||||
|
*/
|
||||||
|
DiscreteValues optimize(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||||
|
*
|
||||||
|
* @param ordering
|
||||||
|
* @return DiscreteValues : MPE
|
||||||
|
*/
|
||||||
|
DiscreteValues optimize(const Ordering& ordering) const;
|
||||||
|
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -33,16 +33,13 @@ namespace gtsam {
|
||||||
|
|
||||||
KeyVector DiscreteKeys::indices() const {
|
KeyVector DiscreteKeys::indices() const {
|
||||||
KeyVector js;
|
KeyVector js;
|
||||||
for(const DiscreteKey& key: *this)
|
for (const DiscreteKey& key : *this) js.push_back(key.first);
|
||||||
js.push_back(key.first);
|
|
||||||
return js;
|
return js;
|
||||||
}
|
}
|
||||||
|
|
||||||
map<Key,size_t> DiscreteKeys::cardinalities() const {
|
map<Key, size_t> DiscreteKeys::cardinalities() const {
|
||||||
map<Key,size_t> cs;
|
map<Key, size_t> cs;
|
||||||
cs.insert(begin(),end());
|
cs.insert(begin(), end());
|
||||||
// for(const DiscreteKey& key: *this)
|
|
||||||
// cs.insert(key);
|
|
||||||
return cs;
|
return cs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,127 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteLookupDAG.cpp
|
||||||
|
* @date Feb 14, 2011
|
||||||
|
* @author Duy-Nguyen Ta
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
using std::pair;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||||
|
void DiscreteLookupTable::print(const std::string& s,
|
||||||
|
const KeyFormatter& formatter) const {
|
||||||
|
using std::cout;
|
||||||
|
using std::endl;
|
||||||
|
|
||||||
|
cout << s << " g( ";
|
||||||
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
if (nrParents()) {
|
||||||
|
cout << "; ";
|
||||||
|
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "):\n";
|
||||||
|
ADT::print("", formatter);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||||
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
DiscreteValues mpe;
|
||||||
|
double maxP = 0;
|
||||||
|
|
||||||
|
// Get all Possible Configurations
|
||||||
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
|
// Find the maximum
|
||||||
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
|
// Update maximum solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = frontalVals;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set values (inPlace) to maximum
|
||||||
|
for (Key j : frontals()) {
|
||||||
|
(*values)[j] = mpe[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||||
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Then, find the max over all remaining
|
||||||
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
|
size_t mpe = 0;
|
||||||
|
double maxP = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
Key j = (firstFrontalKey());
|
||||||
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
frontals[j] = value;
|
||||||
|
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
// Update MPE solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mpe;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
||||||
|
const DiscreteBayesNet& bayesNet) {
|
||||||
|
DiscreteLookupDAG dag;
|
||||||
|
for (auto&& conditional : bayesNet) {
|
||||||
|
if (auto lookupTable =
|
||||||
|
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||||
|
dag.push_back(lookupTable);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dag;
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
||||||
|
// Argmax each node in turn in topological sort order (parents first).
|
||||||
|
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||||
|
lookupTable->argmaxInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
} // namespace gtsam
|
|
@ -0,0 +1,140 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteLookupDAG.h
|
||||||
|
* @date January, 2022
|
||||||
|
* @author Frank dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
class DiscreteBayesNet;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief DiscreteLookupTable table for max-product
|
||||||
|
*
|
||||||
|
* Inherits from discrete conditional for convenience, but is not normalized.
|
||||||
|
* Is used in the max-product algorithm.
|
||||||
|
*/
|
||||||
|
class DiscreteLookupTable : public DiscreteConditional {
|
||||||
|
public:
|
||||||
|
using This = DiscreteLookupTable;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Discrete Lookup Table object
|
||||||
|
*
|
||||||
|
* @param nFrontals number of frontal variables
|
||||||
|
* @param keys a orted list of gtsam::Keys
|
||||||
|
* @param potentials the algebraic decision tree with lookup values
|
||||||
|
*/
|
||||||
|
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const ADT& potentials)
|
||||||
|
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||||
|
|
||||||
|
/// GTSAM-style print
|
||||||
|
void print(
|
||||||
|
const std::string& s = "Discrete Lookup Table: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief return assignment for single frontal variable that maximizes value.
|
||||||
|
* @param parentsValues Known assignments for the parents.
|
||||||
|
* @return maximizing assignment for the frontal variable.
|
||||||
|
*/
|
||||||
|
size_t argmax(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||||
|
* @param (in/out) parentsValues Known assignments for the parents.
|
||||||
|
*/
|
||||||
|
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** A DAG made from lookup tables, as defined above. */
|
||||||
|
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
||||||
|
public:
|
||||||
|
using Base = BayesNet<DiscreteLookupTable>;
|
||||||
|
using This = DiscreteLookupDAG;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Construct empty DAG.
|
||||||
|
DiscreteLookupDAG() {}
|
||||||
|
|
||||||
|
/// Create from BayesNet with LookupTables
|
||||||
|
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
|
||||||
|
|
||||||
|
/// Destructor
|
||||||
|
virtual ~DiscreteLookupDAG() {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Check equality */
|
||||||
|
bool equals(const This& bn, double tol = 1e-9) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Add a DiscreteLookupTable */
|
||||||
|
template <typename... Args>
|
||||||
|
void add(Args&&... args) {
|
||||||
|
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief argmax by back-substitution, optionally given certain variables.
|
||||||
|
*
|
||||||
|
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||||
|
* conditional will be optimized first *and* that the
|
||||||
|
* DAG does not contain any conditionals for the given variables. If the DAG
|
||||||
|
* resulted from eliminating a factor graph, this is true for the elimination
|
||||||
|
* ordering.
|
||||||
|
*
|
||||||
|
* @return given assignment extended w. optimal assignment for all variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
|
@ -111,11 +111,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
gtsam::DecisionTreeFactor* likelihood(
|
gtsam::DecisionTreeFactor* likelihood(
|
||||||
const gtsam::DiscreteValues& frontalValues) const;
|
const gtsam::DiscreteValues& frontalValues) const;
|
||||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
|
||||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
size_t sample(size_t value) const;
|
size_t sample(size_t value) const;
|
||||||
size_t sample() const;
|
size_t sample() const;
|
||||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
|
||||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
@ -138,7 +136,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
double operator()(size_t value) const;
|
double operator()(size_t value) const;
|
||||||
std::vector<double> pmf() const;
|
std::vector<double> pmf() const;
|
||||||
size_t solve() const;
|
size_t argmax() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
@ -163,8 +161,6 @@ class DiscreteBayesNet {
|
||||||
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
|
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
|
||||||
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
|
|
||||||
gtsam::DiscreteValues sample() const;
|
gtsam::DiscreteValues sample() const;
|
||||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
@ -217,6 +213,21 @@ class DiscreteBayesTree {
|
||||||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
class DiscreteLookupDAG {
|
||||||
|
DiscreteLookupDAG();
|
||||||
|
void push_back(const gtsam::DiscreteLookupTable* table);
|
||||||
|
bool empty() const;
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeySet keys() const;
|
||||||
|
const gtsam::DiscreteLookupTable* at(size_t i) const;
|
||||||
|
void print(string s = "DiscreteLookupDAG\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
gtsam::DiscreteValues argmax() const;
|
||||||
|
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
|
||||||
|
};
|
||||||
|
|
||||||
#include <gtsam/inference/DotWriter.h>
|
#include <gtsam/inference/DotWriter.h>
|
||||||
class DotWriter {
|
class DotWriter {
|
||||||
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
|
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
|
||||||
|
@ -260,6 +271,9 @@ class DiscreteFactorGraph {
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
|
||||||
|
gtsam::DiscreteLookupDAG maxProduct();
|
||||||
|
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||||
|
|
||||||
gtsam::DiscreteBayesNet eliminateSequential();
|
gtsam::DiscreteBayesNet eliminateSequential();
|
||||||
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
||||||
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
|
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
|
||||||
|
|
|
@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) {
|
||||||
DiscreteConditional expected2(Bronchitis % "11/9");
|
DiscreteConditional expected2(Bronchitis % "11/9");
|
||||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||||
|
|
||||||
// solve
|
|
||||||
auto actualMPE = chordal->optimize();
|
|
||||||
DiscreteValues expectedMPE;
|
|
||||||
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
|
|
||||||
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
|
|
||||||
LungCancer.first, 0)(Bronchitis.first, 0);
|
|
||||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
|
||||||
|
|
||||||
// add evidence, we were in Asia and we have dyspnea
|
// add evidence, we were in Asia and we have dyspnea
|
||||||
fg.add(Asia, "0 1");
|
fg.add(Asia, "0 1");
|
||||||
fg.add(Dyspnea, "0 1");
|
fg.add(Dyspnea, "0 1");
|
||||||
|
|
||||||
// solve again, now with evidence
|
// solve again, now with evidence
|
||||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||||
auto actualMPE2 = chordal2->optimize();
|
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||||
DiscreteValues expectedMPE2;
|
|
||||||
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
|
|
||||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
|
|
||||||
LungCancer.first, 0)(Bronchitis.first, 1);
|
|
||||||
EXPECT(assert_equal(expectedMPE2, actualMPE2));
|
|
||||||
|
|
||||||
// now sample from it
|
// now sample from it
|
||||||
DiscreteValues expectedSample;
|
DiscreteValues expectedSample;
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* @file testDiscretePrior.cpp
|
* @file testDiscreteDistribution.cpp
|
||||||
* @brief unit tests for DiscreteDistribution
|
* @brief unit tests for DiscreteDistribution
|
||||||
* @author Frank dellaert
|
* @author Frank dellaert
|
||||||
* @date December 2021
|
* @date December 2021
|
||||||
|
@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) {
|
||||||
prior.sample();
|
prior.sample();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteDistribution, argmax) {
|
||||||
|
DiscreteDistribution prior(X % "2/3");
|
||||||
|
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -30,8 +30,8 @@ using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) {
|
||||||
DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3);
|
DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3);
|
||||||
|
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
graph.add(AI, "1 0 0 1");
|
graph.add(AI, "1 0 0 1");
|
||||||
|
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
||||||
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||||
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||||
|
|
||||||
// graph.print("Graph: ");
|
// Check MPE.
|
||||||
DecisionTreeFactor product = graph.product();
|
auto actualMPE = graph.optimize();
|
||||||
DecisionTreeFactor::shared_ptr sum = product.sum(1);
|
DiscreteValues mpe;
|
||||||
// sum->print("Debug SUM: ");
|
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
|
||||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
|
||||||
// cond->print("marginal:");
|
|
||||||
|
|
||||||
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
|
|
||||||
// result.first->print("BayesNet: ");
|
|
||||||
// result.second->print("New factor: ");
|
|
||||||
//
|
|
||||||
Ordering ordering;
|
|
||||||
ordering += Key(0),Key(1),Key(2),Key(3);
|
|
||||||
DiscreteEliminationTree eliminationTree(graph, ordering);
|
|
||||||
// eliminationTree.print("Elimination tree: ");
|
|
||||||
eliminationTree.eliminate(EliminateDiscrete);
|
|
||||||
// solver.optimize();
|
|
||||||
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteFactorGraph, test)
|
TEST(DiscreteFactorGraph, test) {
|
||||||
{
|
|
||||||
// Declare keys and ordering
|
// Declare keys and ordering
|
||||||
DiscreteKey C(0,2), B(1,2), A(2,2);
|
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||||
|
|
||||||
// A simple factor graph (A)-fAC-(C)-fBC-(B)
|
// A simple factor graph (A)-fAC-(C)-fBC-(B)
|
||||||
// with smoothness priors
|
// with smoothness priors
|
||||||
|
@ -127,77 +112,109 @@ TEST( DiscreteFactorGraph, test)
|
||||||
graph.add(C & B, "3 1 1 3");
|
graph.add(C & B, "3 1 1 3");
|
||||||
|
|
||||||
// Test EliminateDiscrete
|
// Test EliminateDiscrete
|
||||||
// FIXME: apparently Eliminate returns a conditional rather than a net
|
|
||||||
Ordering frontalKeys;
|
Ordering frontalKeys;
|
||||||
frontalKeys += Key(0);
|
frontalKeys += Key(0);
|
||||||
DiscreteConditional::shared_ptr conditional;
|
DiscreteConditional::shared_ptr conditional;
|
||||||
DecisionTreeFactor::shared_ptr newFactor;
|
DecisionTreeFactor::shared_ptr newFactor;
|
||||||
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
||||||
|
|
||||||
// Check Bayes net
|
// Check Conditional
|
||||||
CHECK(conditional);
|
CHECK(conditional);
|
||||||
DiscreteBayesNet expected;
|
|
||||||
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
||||||
// cout << signature << endl;
|
|
||||||
DiscreteConditional expectedConditional(signature);
|
DiscreteConditional expectedConditional(signature);
|
||||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||||
expected.add(signature);
|
|
||||||
|
|
||||||
// Check Factor
|
// Check Factor
|
||||||
CHECK(newFactor);
|
CHECK(newFactor);
|
||||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||||
|
|
||||||
// add conditionals to complete expected Bayes net
|
// Test using elimination tree
|
||||||
expected.add(B | A = "5/3 3/5");
|
|
||||||
expected.add(A % "1/1");
|
|
||||||
// GTSAM_PRINT(expected);
|
|
||||||
|
|
||||||
// Test elimination tree
|
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2);
|
ordering += Key(0), Key(1), Key(2);
|
||||||
DiscreteEliminationTree etree(graph, ordering);
|
DiscreteEliminationTree etree(graph, ordering);
|
||||||
DiscreteBayesNet::shared_ptr actual;
|
DiscreteBayesNet::shared_ptr actual;
|
||||||
DiscreteFactorGraph::shared_ptr remainingGraph;
|
DiscreteFactorGraph::shared_ptr remainingGraph;
|
||||||
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
||||||
EXPECT(assert_equal(expected, *actual));
|
|
||||||
|
|
||||||
// // Test solver
|
// Check Bayes net
|
||||||
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
|
DiscreteBayesNet expectedBayesNet;
|
||||||
// EXPECT(assert_equal(expected, *actual2));
|
expectedBayesNet.add(signature);
|
||||||
|
expectedBayesNet.add(B | A = "5/3 3/5");
|
||||||
|
expectedBayesNet.add(A % "1/1");
|
||||||
|
EXPECT(assert_equal(expectedBayesNet, *actual));
|
||||||
|
|
||||||
// Test optimization
|
// Test eliminateSequential
|
||||||
DiscreteValues expectedValues;
|
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
|
||||||
insert(expectedValues)(0, 0)(1, 0)(2, 0);
|
EXPECT(assert_equal(expectedBayesNet, *actual2));
|
||||||
auto actualValues = graph.optimize();
|
|
||||||
EXPECT(assert_equal(expectedValues, actualValues));
|
// Test mpe
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 0)(1, 0)(2, 0);
|
||||||
|
auto actualMPE = graph.optimize();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteFactorGraph, testMPE)
|
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
|
||||||
{
|
|
||||||
// Declare a bunch of keys
|
// Declare a bunch of keys
|
||||||
DiscreteKey C(0,2), A(1,2), B(2,2);
|
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||||
|
|
||||||
// Create Factor graph
|
// Create Factor graph
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||||
// graph.product().print();
|
|
||||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
|
||||||
|
|
||||||
auto actualMPE = graph.optimize();
|
// Created expected MPE
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 0)(1, 1)(2, 1);
|
||||||
|
|
||||||
DiscreteValues expectedMPE;
|
// Do max-product with different orderings
|
||||||
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
|
for (Ordering::OrderingType orderingType :
|
||||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||||
|
Ordering::CUSTOM}) {
|
||||||
|
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
|
||||||
|
auto actualMPE = dag.argmax();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
auto actualMPE2 = graph.optimize(); // all in one
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE2));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
TEST(DiscreteFactorGraph, marginalIsNotMPE) {
|
||||||
{
|
// Declare 2 keys
|
||||||
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
|
||||||
|
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
|
||||||
|
// MPE does not have A=0.
|
||||||
|
DiscreteBayesNet bayesNet;
|
||||||
|
bayesNet.add(B | A = "1/1 1/2");
|
||||||
|
bayesNet.add(A % "10/9");
|
||||||
|
|
||||||
|
// The expected MPE is A=1, B=1
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 1)(1, 1);
|
||||||
|
|
||||||
|
// Which we verify using max-product:
|
||||||
|
DiscreteFactorGraph graph(bayesNet);
|
||||||
|
auto actualMPE = graph.optimize();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
|
||||||
|
auto notOptimal = bayesNet.optimize();
|
||||||
|
EXPECT(graph(notOptimal) < graph(mpe));
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
|
||||||
// The factor graph in Darwiche09book, page 244
|
// The factor graph in Darwiche09book, page 244
|
||||||
DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2);
|
DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
|
||||||
|
|
||||||
// Create Factor graph
|
// Create Factor graph
|
||||||
DiscreteFactorGraph graph;
|
DiscreteFactorGraph graph;
|
||||||
|
@ -206,53 +223,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
||||||
graph.add(C & T1, "0.80 0.20 0.20 0.80");
|
graph.add(C & T1, "0.80 0.20 0.20 0.80");
|
||||||
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
|
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
|
||||||
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
|
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
|
||||||
graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche)
|
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
|
||||||
//graph.product().print("Darwiche-product");
|
|
||||||
// graph.product().potentials().dot("Darwiche-product");
|
|
||||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
|
||||||
|
|
||||||
DiscreteValues expectedMPE;
|
DiscreteValues mpe;
|
||||||
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
|
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
|
||||||
|
// You can check visually by printing product:
|
||||||
|
// graph.product().print("Darwiche-product");
|
||||||
|
|
||||||
// Use the solver machinery.
|
// Check MPE.
|
||||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
auto actualMPE = graph.optimize();
|
||||||
auto actualMPE = chordal->optimize();
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
|
||||||
// DiscreteConditional::shared_ptr root = chordal->back();
|
|
||||||
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
|
|
||||||
|
|
||||||
// Let us create the Bayes tree here, just for fun, because we don't use it now
|
|
||||||
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
|
|
||||||
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
|
|
||||||
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
|
|
||||||
//// bayesTree->print("Bayes Tree");
|
|
||||||
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
|
||||||
|
|
||||||
|
// Check Bayes Net
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0),Key(1),Key(2),Key(3),Key(4);
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
||||||
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
|
auto chordal = graph.eliminateSequential(ordering);
|
||||||
// bayesTree->print("Bayes Tree");
|
EXPECT_LONGS_EQUAL(5, chordal->size());
|
||||||
EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
|
auto notOptimal = chordal->optimize(); // not MPE !
|
||||||
#ifdef OLD
|
EXPECT(graph(notOptimal) < graph(mpe));
|
||||||
// Create the elimination tree manually
|
|
||||||
VariableIndexOrdered structure(graph);
|
|
||||||
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
|
|
||||||
ETree::shared_ptr eTree = ETree::Create(graph, structure);
|
|
||||||
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
|
|
||||||
|
|
||||||
// eliminate normally and check solution
|
|
||||||
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
|
|
||||||
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
|
|
||||||
auto actualMPE = optimize(*bayesNet);
|
|
||||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
|
||||||
|
|
||||||
// Approximate and check solution
|
|
||||||
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
|
|
||||||
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
|
|
||||||
// EXPECT(assert_equal(expectedMPE, *actualMPE));
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Let us create the Bayes tree here, just for fun, because we don't use it
|
||||||
|
DiscreteBayesTree::shared_ptr bayesTree =
|
||||||
|
graph.eliminateMultifrontal(ordering);
|
||||||
|
// bayesTree->print("Bayes Tree");
|
||||||
|
EXPECT_LONGS_EQUAL(2, bayesTree->size());
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef OLD
|
#ifdef OLD
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* testDiscreteLookupDAG.cpp
|
||||||
|
*
|
||||||
|
* @date January, 2022
|
||||||
|
* @author Frank Dellaert
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
|
||||||
|
#include <boost/assign/list_inserter.hpp>
|
||||||
|
#include <boost/assign/std/map.hpp>
|
||||||
|
|
||||||
|
using namespace gtsam;
|
||||||
|
using namespace boost::assign;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscreteLookupDAG, argmax) {
|
||||||
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
|
||||||
|
// Declare 2 keys
|
||||||
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
|
||||||
|
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
|
||||||
|
DiscreteLookupDAG dag;
|
||||||
|
|
||||||
|
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
|
||||||
|
dag.add(1, DiscreteKeys{B, A}, adtB);
|
||||||
|
|
||||||
|
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
|
||||||
|
dag.add(1, DiscreteKeys{A}, adtA);
|
||||||
|
|
||||||
|
// The expected MPE is A=1, B=1
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 1)(1, 1);
|
||||||
|
|
||||||
|
// check:
|
||||||
|
auto actualMPE = dag.argmax();
|
||||||
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
|
@ -25,15 +25,12 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO: Update comments. The following comments are out of date!!!
|
* Base class for conditional densities. This class iterators and
|
||||||
*
|
|
||||||
* Base class for conditional densities, templated on KEY type. This class
|
|
||||||
* provides storage for the keys involved in a conditional, and iterators and
|
|
||||||
* access to the frontal and separator keys.
|
* access to the frontal and separator keys.
|
||||||
*
|
*
|
||||||
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
|
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
|
||||||
* to the associated factor type and shared_ptr type of the derived class. See
|
* to the associated factor type and shared_ptr type of the derived class. See
|
||||||
* IndexConditional and GaussianConditional for examples.
|
* SymbolicConditional and GaussianConditional for examples.
|
||||||
* \nosubgrouping
|
* \nosubgrouping
|
||||||
*/
|
*/
|
||||||
template<class FACTOR, class DERIVEDCONDITIONAL>
|
template<class FACTOR, class DERIVEDCONDITIONAL>
|
||||||
|
|
|
@ -14,18 +14,6 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive
|
|
||||||
DiscreteValues CSP::optimalAssignment() const {
|
|
||||||
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
|
|
||||||
return chordal->optimize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive
|
|
||||||
DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const {
|
|
||||||
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
|
|
||||||
return chordal->optimize();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool CSP::runArcConsistency(const VariableIndex& index,
|
bool CSP::runArcConsistency(const VariableIndex& index,
|
||||||
Domains* domains) const {
|
Domains* domains) const {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
|
|
|
@ -43,12 +43,6 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
|
||||||
// return result;
|
// return result;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/// Find the best total assignment - can be expensive.
|
|
||||||
DiscreteValues optimalAssignment() const;
|
|
||||||
|
|
||||||
/// Find the best total assignment, with given ordering - can be expensive.
|
|
||||||
DiscreteValues optimalAssignment(const Ordering& ordering) const;
|
|
||||||
|
|
||||||
// /*
|
// /*
|
||||||
// * Perform loopy belief propagation
|
// * Perform loopy belief propagation
|
||||||
// * True belief propagation would check for each value in domain
|
// * True belief propagation would check for each value in domain
|
||||||
|
|
|
@ -255,23 +255,6 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
|
||||||
return chordal;
|
return chordal;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Find the best total assignment - can be expensive */
|
|
||||||
DiscreteValues Scheduler::optimalAssignment() const {
|
|
||||||
DiscreteBayesNet::shared_ptr chordal = eliminate();
|
|
||||||
|
|
||||||
if (ISDEBUG("Scheduler::optimalAssignment")) {
|
|
||||||
DiscreteBayesNet::const_iterator it = chordal->end() - 1;
|
|
||||||
const Student& student = students_.front();
|
|
||||||
cout << endl;
|
|
||||||
(*it)->print(student.name_);
|
|
||||||
}
|
|
||||||
|
|
||||||
gttic(my_optimize);
|
|
||||||
DiscreteValues mpe = chordal->optimize();
|
|
||||||
gttoc(my_optimize);
|
|
||||||
return mpe;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** find the assignment of students to slots with most possible committees */
|
/** find the assignment of students to slots with most possible committees */
|
||||||
DiscreteValues Scheduler::bestSchedule() const {
|
DiscreteValues Scheduler::bestSchedule() const {
|
||||||
DiscreteValues best;
|
DiscreteValues best;
|
||||||
|
|
|
@ -147,9 +147,6 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
|
||||||
/** Eliminate, return a Bayes net */
|
/** Eliminate, return a Bayes net */
|
||||||
DiscreteBayesNet::shared_ptr eliminate() const;
|
DiscreteBayesNet::shared_ptr eliminate() const;
|
||||||
|
|
||||||
/** Find the best total assignment - can be expensive */
|
|
||||||
DiscreteValues optimalAssignment() const;
|
|
||||||
|
|
||||||
/** find the assignment of students to slots with most possible committees */
|
/** find the assignment of students to slots with most possible committees */
|
||||||
DiscreteValues bestSchedule() const;
|
DiscreteValues bestSchedule() const;
|
||||||
|
|
||||||
|
|
|
@ -122,7 +122,7 @@ void runLargeExample() {
|
||||||
// SETDEBUG("timing-verbose", true);
|
// SETDEBUG("timing-verbose", true);
|
||||||
SETDEBUG("DiscreteConditional::DiscreteConditional", true);
|
SETDEBUG("DiscreteConditional::DiscreteConditional", true);
|
||||||
gttic(large);
|
gttic(large);
|
||||||
auto MPE = scheduler.optimalAssignment();
|
auto MPE = scheduler.optimize();
|
||||||
gttoc(large);
|
gttoc(large);
|
||||||
tictoc_finishedIteration();
|
tictoc_finishedIteration();
|
||||||
tictoc_print();
|
tictoc_print();
|
||||||
|
@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) {
|
||||||
root->print(""/*scheduler.studentName(s)*/);
|
root->print(""/*scheduler.studentName(s)*/);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
DiscreteValues values;
|
size_t bestSlot = root->argmax();
|
||||||
size_t bestSlot = root->solve(values);
|
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
DiscreteKey dkey = scheduler.studentKey(6 - s);
|
DiscreteKey dkey = scheduler.studentKey(6 - s);
|
||||||
|
DiscreteValues values;
|
||||||
values[dkey.first] = bestSlot;
|
values[dkey.first] = bestSlot;
|
||||||
size_t count = (*root)(values);
|
size_t count = (*root)(values);
|
||||||
|
|
||||||
|
@ -319,11 +319,11 @@ void accomodateStudent() {
|
||||||
// GTSAM_PRINT(*chordal);
|
// GTSAM_PRINT(*chordal);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
DiscreteValues values;
|
size_t bestSlot = root->argmax();
|
||||||
size_t bestSlot = root->solve(values);
|
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
DiscreteKey dkey = scheduler.studentKey(0);
|
DiscreteKey dkey = scheduler.studentKey(0);
|
||||||
|
DiscreteValues values;
|
||||||
values[dkey.first] = bestSlot;
|
values[dkey.first] = bestSlot;
|
||||||
size_t count = (*root)(values);
|
size_t count = (*root)(values);
|
||||||
cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0)
|
cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0)
|
||||||
|
|
|
@ -143,7 +143,7 @@ void runLargeExample() {
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gttic(large);
|
gttic(large);
|
||||||
auto MPE = scheduler.optimalAssignment();
|
auto MPE = scheduler.optimize();
|
||||||
gttoc(large);
|
gttoc(large);
|
||||||
tictoc_finishedIteration();
|
tictoc_finishedIteration();
|
||||||
tictoc_print();
|
tictoc_print();
|
||||||
|
@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) {
|
||||||
root->print(""/*scheduler.studentName(s)*/);
|
root->print(""/*scheduler.studentName(s)*/);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
DiscreteValues values;
|
size_t bestSlot = root->argmax();
|
||||||
size_t bestSlot = root->solve(values);
|
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
||||||
|
DiscreteValues values;
|
||||||
values[dkey.first] = bestSlot;
|
values[dkey.first] = bestSlot;
|
||||||
size_t count = (*root)(values);
|
size_t count = (*root)(values);
|
||||||
|
|
||||||
|
|
|
@ -167,7 +167,7 @@ void runLargeExample() {
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
gttic(large);
|
gttic(large);
|
||||||
auto MPE = scheduler.optimalAssignment();
|
auto MPE = scheduler.optimize();
|
||||||
gttoc(large);
|
gttoc(large);
|
||||||
tictoc_finishedIteration();
|
tictoc_finishedIteration();
|
||||||
tictoc_print();
|
tictoc_print();
|
||||||
|
@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) {
|
||||||
root->print(""/*scheduler.studentName(s)*/);
|
root->print(""/*scheduler.studentName(s)*/);
|
||||||
|
|
||||||
// solve root node only
|
// solve root node only
|
||||||
DiscreteValues values;
|
size_t bestSlot = root->argmax();
|
||||||
size_t bestSlot = root->solve(values);
|
|
||||||
|
|
||||||
// get corresponding count
|
// get corresponding count
|
||||||
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
||||||
|
DiscreteValues values;
|
||||||
values[dkey.first] = bestSlot;
|
values[dkey.first] = bestSlot;
|
||||||
double count = (*root)(values);
|
double count = (*root)(values);
|
||||||
|
|
||||||
|
|
|
@ -132,7 +132,7 @@ TEST(CSP, allInOne) {
|
||||||
EXPECT(assert_equal(expectedProduct, product));
|
EXPECT(assert_equal(expectedProduct, product));
|
||||||
|
|
||||||
// Solve
|
// Solve
|
||||||
auto mpe = csp.optimalAssignment();
|
auto mpe = csp.optimize();
|
||||||
DiscreteValues expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
|
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
|
||||||
EXPECT(assert_equal(expected, mpe));
|
EXPECT(assert_equal(expected, mpe));
|
||||||
|
@ -172,22 +172,18 @@ TEST(CSP, WesternUS) {
|
||||||
csp.addAllDiff(WY, CO);
|
csp.addAllDiff(WY, CO);
|
||||||
csp.addAllDiff(CO, NM);
|
csp.addAllDiff(CO, NM);
|
||||||
|
|
||||||
|
DiscreteValues mpe;
|
||||||
|
insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0);
|
||||||
|
|
||||||
// Create ordering according to example in ND-CSP.lyx
|
// Create ordering according to example in ND-CSP.lyx
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
|
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
|
||||||
Key(8), Key(9), Key(10);
|
Key(8), Key(9), Key(10);
|
||||||
// Solve using that ordering:
|
|
||||||
auto mpe = csp.optimalAssignment(ordering);
|
|
||||||
// GTSAM_PRINT(mpe);
|
|
||||||
DiscreteValues expected;
|
|
||||||
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
|
|
||||||
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
|
|
||||||
UT.first, 1)(AZ.first, 0);
|
|
||||||
|
|
||||||
// TODO: Fix me! mpe result seems to be right. (See the printing)
|
// Solve using that ordering:
|
||||||
// It has the same prob as the expected solution.
|
auto actualMPE = csp.optimize(ordering);
|
||||||
// Is mpe another solution, or the expected solution is unique???
|
|
||||||
EXPECT(assert_equal(expected, mpe));
|
EXPECT(assert_equal(mpe, actualMPE));
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
|
||||||
|
|
||||||
// Write out the dual graph for hmetis
|
// Write out the dual graph for hmetis
|
||||||
|
@ -227,7 +223,7 @@ TEST(CSP, ArcConsistency) {
|
||||||
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
||||||
|
|
||||||
// Solve
|
// Solve
|
||||||
auto mpe = csp.optimalAssignment();
|
auto mpe = csp.optimize();
|
||||||
DiscreteValues expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
|
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
|
||||||
EXPECT(assert_equal(expected, mpe));
|
EXPECT(assert_equal(expected, mpe));
|
||||||
|
|
|
@ -122,7 +122,7 @@ TEST(schedulingExample, test) {
|
||||||
|
|
||||||
// Do exact inference
|
// Do exact inference
|
||||||
gttic(small);
|
gttic(small);
|
||||||
auto MPE = s.optimalAssignment();
|
auto MPE = s.optimize();
|
||||||
gttoc(small);
|
gttoc(small);
|
||||||
|
|
||||||
// print MPE, commented out as unit tests don't print
|
// print MPE, commented out as unit tests don't print
|
||||||
|
|
|
@ -100,7 +100,7 @@ class Sudoku : public CSP {
|
||||||
|
|
||||||
/// solve and print solution
|
/// solve and print solution
|
||||||
void printSolution() const {
|
void printSolution() const {
|
||||||
auto MPE = optimalAssignment();
|
auto MPE = optimize();
|
||||||
printAssignment(MPE);
|
printAssignment(MPE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -126,7 +126,7 @@ TEST(Sudoku, small) {
|
||||||
0, 1, 0, 0);
|
0, 1, 0, 0);
|
||||||
|
|
||||||
// optimize and check
|
// optimize and check
|
||||||
auto solution = csp.optimalAssignment();
|
auto solution = csp.optimize();
|
||||||
DiscreteValues expected;
|
DiscreteValues expected;
|
||||||
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
|
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
|
||||||
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
|
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
|
||||||
|
@ -148,7 +148,7 @@ TEST(Sudoku, small) {
|
||||||
EXPECT_LONGS_EQUAL(16, new_csp.size());
|
EXPECT_LONGS_EQUAL(16, new_csp.size());
|
||||||
|
|
||||||
// Check that solution
|
// Check that solution
|
||||||
auto new_solution = new_csp.optimalAssignment();
|
auto new_solution = new_csp.optimize();
|
||||||
// csp.printAssignment(new_solution);
|
// csp.printAssignment(new_solution);
|
||||||
EXPECT(assert_equal(expected, new_solution));
|
EXPECT(assert_equal(expected, new_solution));
|
||||||
}
|
}
|
||||||
|
@ -250,7 +250,7 @@ TEST(Sudoku, AJC_3star_Feb8_2012) {
|
||||||
EXPECT_LONGS_EQUAL(81, new_csp.size());
|
EXPECT_LONGS_EQUAL(81, new_csp.size());
|
||||||
|
|
||||||
// Check that solution
|
// Check that solution
|
||||||
auto solution = new_csp.optimalAssignment();
|
auto solution = new_csp.optimize();
|
||||||
// csp.printAssignment(solution);
|
// csp.printAssignment(solution);
|
||||||
EXPECT_LONGS_EQUAL(6, solution.at(key99));
|
EXPECT_LONGS_EQUAL(6, solution.at(key99));
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
self.gtsamAssertEquals(chordal.at(7), expected2)
|
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||||
|
|
||||||
# solve
|
# solve
|
||||||
actualMPE = chordal.optimize()
|
actualMPE = fg.optimize()
|
||||||
expectedMPE = DiscreteValues()
|
expectedMPE = DiscreteValues()
|
||||||
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
||||||
expectedMPE[key[0]] = 0
|
expectedMPE[key[0]] = 0
|
||||||
|
@ -94,8 +94,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
fg.add(Dyspnea, "0 1")
|
fg.add(Dyspnea, "0 1")
|
||||||
|
|
||||||
# solve again, now with evidence
|
# solve again, now with evidence
|
||||||
chordal2 = fg.eliminateSequential(ordering)
|
actualMPE2 = fg.optimize()
|
||||||
actualMPE2 = chordal2.optimize()
|
|
||||||
expectedMPE2 = DiscreteValues()
|
expectedMPE2 = DiscreteValues()
|
||||||
for key in [XRay, Tuberculosis, Either, LungCancer]:
|
for key in [XRay, Tuberculosis, Either, LungCancer]:
|
||||||
expectedMPE2[key[0]] = 0
|
expectedMPE2[key[0]] = 0
|
||||||
|
@ -105,6 +104,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
list(expectedMPE2.items()))
|
list(expectedMPE2.items()))
|
||||||
|
|
||||||
# now sample from it
|
# now sample from it
|
||||||
|
chordal2 = fg.eliminateSequential(ordering)
|
||||||
actualSample = chordal2.sample()
|
actualSample = chordal2.sample()
|
||||||
self.assertEqual(len(actualSample), 8)
|
self.assertEqual(len(actualSample), 8)
|
||||||
|
|
||||||
|
@ -122,10 +122,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
for key in [Asia, Smoking]:
|
for key in [Asia, Smoking]:
|
||||||
given[key[0]] = 0
|
given[key[0]] = 0
|
||||||
|
|
||||||
# Now optimize fragment:
|
|
||||||
actual = fragment.optimize(given)
|
|
||||||
self.assertEqual(len(actual), 5)
|
|
||||||
|
|
||||||
# Now sample from fragment:
|
# Now sample from fragment:
|
||||||
actual = fragment.sample(given)
|
actual = fragment.sample(given)
|
||||||
self.assertEqual(len(actual), 5)
|
self.assertEqual(len(actual), 5)
|
||||||
|
|
|
@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
|
||||||
X = 0, 2
|
X = 0, 2
|
||||||
|
|
||||||
|
|
||||||
class TestDiscretePrior(GtsamTestCase):
|
class TestDiscreteDistribution(GtsamTestCase):
|
||||||
"""Tests for Discrete Priors."""
|
"""Tests for Discrete Priors."""
|
||||||
|
|
||||||
def test_constructor(self):
|
def test_constructor(self):
|
||||||
|
|
Loading…
Reference in New Issue