Merge pull request #1050 from borglab/feature/betterMPE
commit
b441eea976
|
@ -53,10 +53,9 @@ int main(int argc, char **argv) {
|
|||
// Create solver and eliminate
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
|
||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||
|
||||
// solve
|
||||
auto mpe = chordal->optimize();
|
||||
auto mpe = fg.optimize();
|
||||
GTSAM_PRINT(mpe);
|
||||
|
||||
// We can also build a Bayes tree (directed junction tree).
|
||||
|
@ -69,14 +68,14 @@ int main(int argc, char **argv) {
|
|||
fg.add(Dyspnea, "0 1");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||
auto mpe2 = chordal2->optimize();
|
||||
auto mpe2 = fg.optimize();
|
||||
GTSAM_PRINT(mpe2);
|
||||
|
||||
// We can also sample from it
|
||||
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
|
||||
cout << "\n10 samples:" << endl;
|
||||
for (size_t i = 0; i < 10; i++) {
|
||||
auto sample = chordal2->sample();
|
||||
auto sample = chordal->sample();
|
||||
GTSAM_PRINT(sample);
|
||||
}
|
||||
return 0;
|
||||
|
|
|
@ -85,7 +85,7 @@ int main(int argc, char **argv) {
|
|||
}
|
||||
|
||||
// "Most Probable Explanation", i.e., configuration with largest value
|
||||
auto mpe = graph.eliminateSequential()->optimize();
|
||||
auto mpe = graph.optimize();
|
||||
cout << "\nMost Probable Explanation (MPE):" << endl;
|
||||
print(mpe);
|
||||
|
||||
|
@ -96,8 +96,7 @@ int main(int argc, char **argv) {
|
|||
graph.add(Cloudy, "1 0");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
auto mpe_with_evidence = chordal->optimize();
|
||||
auto mpe_with_evidence = graph.optimize();
|
||||
|
||||
cout << "\nMPE given C=0:" << endl;
|
||||
print(mpe_with_evidence);
|
||||
|
@ -110,7 +109,8 @@ int main(int argc, char **argv) {
|
|||
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
|
||||
<< endl;
|
||||
|
||||
// We can also sample from it
|
||||
// We can also sample from the eliminated graph
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
cout << "\n10 samples:" << endl;
|
||||
for (size_t i = 0; i < 10; i++) {
|
||||
auto sample = chordal->sample();
|
||||
|
|
|
@ -59,16 +59,16 @@ int main(int argc, char **argv) {
|
|||
// Convert to factor graph
|
||||
DiscreteFactorGraph factorGraph(hmm);
|
||||
|
||||
// Do max-prodcut
|
||||
auto mpe = factorGraph.optimize();
|
||||
GTSAM_PRINT(mpe);
|
||||
|
||||
// Create solver and eliminate
|
||||
// This will create a DAG ordered with arrow of time reversed
|
||||
DiscreteBayesNet::shared_ptr chordal =
|
||||
factorGraph.eliminateSequential(ordering);
|
||||
chordal->print("Eliminated");
|
||||
|
||||
// solve
|
||||
auto mpe = chordal->optimize();
|
||||
GTSAM_PRINT(mpe);
|
||||
|
||||
// We can also sample from it
|
||||
cout << "\n10 samples:" << endl;
|
||||
for (size_t k = 0; k < 10; k++) {
|
||||
|
|
|
@ -68,9 +68,8 @@ int main(int argc, char** argv) {
|
|||
<< graph.size() << " factors (Unary+Edge).";
|
||||
|
||||
// "Decoding", i.e., configuration with largest value
|
||||
// We use sequential variable elimination
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
auto optimalDecoding = chordal->optimize();
|
||||
// Uses max-product.
|
||||
auto optimalDecoding = graph.optimize();
|
||||
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");
|
||||
|
||||
// "Inference" Computing marginals for each node
|
||||
|
|
|
@ -61,9 +61,8 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
// "Decoding", i.e., configuration with largest value (MPE)
|
||||
// We use sequential variable elimination
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
auto optimalDecoding = chordal->optimize();
|
||||
// Uses max-product
|
||||
auto optimalDecoding = graph.optimize();
|
||||
GTSAM_PRINT(optimalDecoding);
|
||||
|
||||
// "Inference" Computing marginals
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <gtsam/base/FastSet.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <boost/format.hpp>
|
||||
#include <utility>
|
||||
|
||||
using namespace std;
|
||||
|
@ -67,6 +68,10 @@ namespace gtsam {
|
|||
void DecisionTreeFactor::print(const string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << s;
|
||||
cout << " f[";
|
||||
for (auto&& key : keys())
|
||||
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
|
||||
cout << " ]" << endl;
|
||||
ADT::print("Potentials:", formatter);
|
||||
}
|
||||
|
||||
|
|
|
@ -127,11 +127,16 @@ namespace gtsam {
|
|||
return combine(keys, ADT::Ring::add);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator values
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(size_t nrFrontals) const {
|
||||
return combine(nrFrontals, ADT::Ring::max);
|
||||
}
|
||||
|
||||
/// Create new factor by maximizing over all values with the same separator.
|
||||
shared_ptr max(const Ordering& keys) const {
|
||||
return combine(keys, ADT::Ring::max);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
|
|
@ -43,6 +43,7 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||
DiscreteValues result;
|
||||
return optimize(result);
|
||||
|
@ -50,10 +51,16 @@ DiscreteValues DiscreteBayesNet::optimize() const {
|
|||
|
||||
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
|
||||
// solve each node in turn in topological sort order (parents first)
|
||||
#ifdef _MSC_VER
|
||||
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
|
||||
#else
|
||||
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
|
||||
#endif
|
||||
for (auto conditional : boost::adaptors::reverse(*this))
|
||||
conditional->solveInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::sample() const {
|
||||
|
|
|
@ -31,12 +31,12 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/** A Bayes net made from linear-Discrete densities */
|
||||
/** A Bayes net made from discrete conditional distributions. */
|
||||
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
|
||||
{
|
||||
public:
|
||||
|
||||
typedef FactorGraph<DiscreteConditional> Base;
|
||||
typedef BayesNet<DiscreteConditional> Base;
|
||||
typedef DiscreteBayesNet This;
|
||||
typedef DiscreteConditional ConditionalType;
|
||||
typedef boost::shared_ptr<This> shared_ptr;
|
||||
|
@ -45,7 +45,7 @@ namespace gtsam {
|
|||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Construct empty factor graph */
|
||||
/// Construct empty Bayes net.
|
||||
DiscreteBayesNet() {}
|
||||
|
||||
/** Construct from iterator over conditionals */
|
||||
|
@ -98,27 +98,6 @@ namespace gtsam {
|
|||
return evaluate(values);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief solve by back-substitution.
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
||||
* conditional will be optimized first. If the Bayes net resulted from
|
||||
* eliminating a factor graph, this is true for the elimination ordering.
|
||||
*
|
||||
* @return a sampled value for all variables.
|
||||
*/
|
||||
DiscreteValues optimize() const;
|
||||
|
||||
/**
|
||||
* @brief solve by back-substitution, given certain variables.
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
||||
* Bayes net does not contain any conditionals for the given values.
|
||||
*
|
||||
* @return given values extended with optimized value for other variables.
|
||||
*/
|
||||
DiscreteValues optimize(DiscreteValues given) const;
|
||||
|
||||
/**
|
||||
* @brief do ancestral sampling
|
||||
*
|
||||
|
@ -154,6 +133,15 @@ namespace gtsam {
|
|||
|
||||
///@}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated functionality
|
||||
/// @{
|
||||
|
||||
DiscreteValues GTSAM_DEPRECATED optimize() const;
|
||||
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
|
||||
/// @}
|
||||
#endif
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
@ -16,26 +16,25 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using std::pair;
|
||||
using std::stringstream;
|
||||
using std::vector;
|
||||
using std::pair;
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
|
@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
|
|||
cout << endl;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
|
@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
|
|||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||
const DiscreteValues& given,
|
||||
bool forceComplete = true) {
|
||||
DiscreteConditional::ADT DiscreteConditional::choose(
|
||||
const DiscreteValues& given, bool forceComplete) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
DiscreteConditional::ADT adt(conditional);
|
||||
DiscreteConditional::ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : conditional.parents()) {
|
||||
for (Key j : parents()) {
|
||||
try {
|
||||
value = given.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
|
@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
|||
if (forceComplete) {
|
||||
given.print("parentsValues: ");
|
||||
throw runtime_error(
|
||||
"DiscreteConditional::Choose: parent value missing");
|
||||
"DiscreteConditional::choose: parent value missing");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
|||
/* ************************************************************************** */
|
||||
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||
const DiscreteValues& given) const {
|
||||
ADT adt = Choose(*this, given, false); // P(F|S=given)
|
||||
ADT adt = choose(given, false); // P(F|S=given)
|
||||
|
||||
// Collect all keys not in given.
|
||||
DiscreteKeys dKeys;
|
||||
|
@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ****************************************************************************/
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
size_t parent_value) const {
|
||||
if (nrFrontals() != 1)
|
||||
|
@ -238,8 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
|
@ -248,23 +247,66 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
|||
// Get all Possible Configurations
|
||||
const auto allPosbValues = frontalAssignments();
|
||||
|
||||
// Find the MPE
|
||||
// Find the maximum
|
||||
for (const auto& frontalVals : allPosbValues) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
// Update maximum solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = frontalVals;
|
||||
}
|
||||
}
|
||||
|
||||
// set values (inPlace) to mpe
|
||||
// set values (inPlace) to maximum
|
||||
for (Key j : frontals()) {
|
||||
(*values)[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
size_t max = 0;
|
||||
double maxP = 0;
|
||||
DiscreteValues frontals;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
// Update solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
max = value;
|
||||
}
|
||||
}
|
||||
return max;
|
||||
}
|
||||
#endif
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::argmax() const {
|
||||
size_t maxValue = 0;
|
||||
double maxP = 0;
|
||||
assert(nrFrontals() == 1);
|
||||
assert(nrParents() == 0);
|
||||
DiscreteValues frontals;
|
||||
Key j = firstFrontalKey();
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = (*this)(frontals);
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
maxValue = value;
|
||||
}
|
||||
}
|
||||
return maxValue;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
|
@ -273,34 +315,11 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO, only works for one key now, seems horribly slow this way
|
||||
size_t mpe = 0;
|
||||
DiscreteValues frontals;
|
||||
double maxP = 0;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = value;
|
||||
}
|
||||
}
|
||||
return mpe;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||
static mt19937 rng(2); // random number generator
|
||||
|
||||
// Get the correct conditional density
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
if (nrFrontals() != 1) {
|
||||
|
@ -323,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|||
return distribution(rng);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||
if (nrParents() != 1)
|
||||
throw std::invalid_argument(
|
||||
|
@ -334,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
|
|||
return sample(values);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::sample() const {
|
||||
if (nrParents() != 0)
|
||||
throw std::invalid_argument(
|
||||
|
|
|
@ -179,13 +179,6 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/** Single variable version of likelihood. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
* @param parentsValues Known values of the parents
|
||||
* @return MPE value of the child (1 frontal variable).
|
||||
*/
|
||||
size_t solve(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/**
|
||||
* sample
|
||||
* @param parentsValues Known values of the parents
|
||||
|
@ -199,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/// Zero parent version.
|
||||
size_t sample() const;
|
||||
|
||||
/**
|
||||
* @brief Return assignment that maximizes distribution.
|
||||
* @return Optimal assignment (1 frontal variable).
|
||||
*/
|
||||
size_t argmax() const;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
||||
/// solve a conditional, in place
|
||||
void solveInPlace(DiscreteValues* parentsValues) const;
|
||||
|
||||
/// sample in place, stores result in partial solution
|
||||
void sampleInPlace(DiscreteValues* parentsValues) const;
|
||||
|
||||
|
@ -228,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
const Names& names = {}) const override;
|
||||
|
||||
/// @}
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated functionality
|
||||
/// @{
|
||||
size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const;
|
||||
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
|
||||
/// @}
|
||||
#endif
|
||||
|
||||
protected:
|
||||
/// Internal version of choose
|
||||
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||
bool forceComplete) const;
|
||||
};
|
||||
// DiscreteConditional
|
||||
|
||||
|
|
|
@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional {
|
|||
/// Return entire probability mass function.
|
||||
std::vector<double> pmf() const;
|
||||
|
||||
/**
|
||||
* solve a conditional
|
||||
* @return MPE value of the child (1 frontal variable).
|
||||
*/
|
||||
size_t solve() const { return Base::solve({}); }
|
||||
|
||||
/**
|
||||
* sample
|
||||
* @return sample from conditional
|
||||
*/
|
||||
size_t sample() const { return Base::sample(); }
|
||||
|
||||
/// @}
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
/// @name Deprecated functionality
|
||||
/// @{
|
||||
size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); }
|
||||
/// @}
|
||||
#endif
|
||||
};
|
||||
// DiscreteDistribution
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph-inst.h>
|
||||
#include <gtsam/inference/FactorGraph-inst.h>
|
||||
|
||||
|
@ -95,22 +96,85 @@ namespace gtsam {
|
|||
// }
|
||||
// }
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteFactorGraph::optimize() const
|
||||
{
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
return BaseEliminateable::eliminateSequential()->optimize();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
// Alternate eliminate function for MPE
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) {
|
||||
|
||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DecisionTreeFactor product;
|
||||
for(const DiscreteFactor::shared_ptr& factor: factors)
|
||||
product = (*factor) * product;
|
||||
for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
// max out frontals, this is the factor on the separator
|
||||
gttic(max);
|
||||
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
|
||||
gttoc(max);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
DiscreteKeys orderedKeys;
|
||||
for (auto&& key : frontalKeys)
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
for (auto&& key : max->keys())
|
||||
orderedKeys.emplace_back(key, product.cardinality(key));
|
||||
|
||||
// Make lookup with product
|
||||
gttic(lookup);
|
||||
size_t nrFrontals = frontalKeys.size();
|
||||
auto lookup = boost::make_shared<DiscreteLookupTable>(nrFrontals,
|
||||
orderedKeys, product);
|
||||
gttoc(lookup);
|
||||
|
||||
return std::make_pair(
|
||||
boost::dynamic_pointer_cast<DiscreteConditional>(lookup), max);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// The max-product solution below is a bit clunky: the elimination machinery
|
||||
// does not allow for differently *typed* versions of elimination, so we
|
||||
// eliminate into a Bayes Net using the special eliminate function above, and
|
||||
// then create the DiscreteLookupDAG after the fact, in linear time.
|
||||
|
||||
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_maxProduct);
|
||||
auto bayesNet =
|
||||
BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE);
|
||||
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||
}
|
||||
|
||||
DiscreteLookupDAG DiscreteFactorGraph::maxProduct(
|
||||
const Ordering& ordering) const {
|
||||
gttic(DiscreteFactorGraph_maxProduct);
|
||||
auto bayesNet =
|
||||
BaseEliminateable::eliminateSequential(ordering, EliminateForMPE);
|
||||
return DiscreteLookupDAG::FromBayesNet(*bayesNet);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteValues DiscreteFactorGraph::optimize(
|
||||
OptionalOrderingType orderingType) const {
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
DiscreteLookupDAG dag = maxProduct(orderingType);
|
||||
return dag.argmax();
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteFactorGraph::optimize(
|
||||
const Ordering& ordering) const {
|
||||
gttic(DiscreteFactorGraph_optimize);
|
||||
DiscreteLookupDAG dag = maxProduct(ordering);
|
||||
return dag.argmax();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DecisionTreeFactor product;
|
||||
for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
|
@ -120,15 +184,18 @@ namespace gtsam {
|
|||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
Ordering orderedKeys;
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
|
||||
frontalKeys.end());
|
||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(),
|
||||
sum->keys().end());
|
||||
|
||||
// now divide product/sum to get conditional
|
||||
gttic(divide);
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys));
|
||||
auto conditional =
|
||||
boost::make_shared<DiscreteConditional>(product, *sum, orderedKeys);
|
||||
gttoc(divide);
|
||||
|
||||
return std::make_pair(cond, sum);
|
||||
return std::make_pair(conditional, sum);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
@ -18,10 +18,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/inference/EliminateableFactorGraph.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/base/FastSet.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
@ -128,18 +129,39 @@ class GTSAM_EXPORT DiscreteFactorGraph
|
|||
const std::string& s = "DiscreteFactorGraph",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/** Solve the factor graph by performing variable elimination in COLAMD order using
|
||||
* the dense elimination function specified in \c function,
|
||||
* followed by back-substitution resulting from elimination. Is equivalent
|
||||
* to calling graph.eliminateSequential()->optimize(). */
|
||||
DiscreteValues optimize() const;
|
||||
/**
|
||||
* @brief Implement the max-product algorithm
|
||||
*
|
||||
* @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM
|
||||
* @return DiscreteLookupDAG::shared_ptr DAG with lookup tables
|
||||
*/
|
||||
DiscreteLookupDAG maxProduct(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Implement the max-product algorithm
|
||||
*
|
||||
* @param ordering
|
||||
* @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables
|
||||
*/
|
||||
DiscreteLookupDAG maxProduct(const Ordering& ordering) const;
|
||||
|
||||
// /** Permute the variables in the factors */
|
||||
// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation);
|
||||
//
|
||||
// /** Apply a reduction, which is a remapping of variable indices. */
|
||||
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);
|
||||
/**
|
||||
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||
*
|
||||
* @param orderingType
|
||||
* @return DiscreteValues : MPE
|
||||
*/
|
||||
DiscreteValues optimize(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Find the maximum probable explanation (MPE) by doing max-product.
|
||||
*
|
||||
* @param ordering
|
||||
* @return DiscreteValues : MPE
|
||||
*/
|
||||
DiscreteValues optimize(const Ordering& ordering) const;
|
||||
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -33,16 +33,13 @@ namespace gtsam {
|
|||
|
||||
KeyVector DiscreteKeys::indices() const {
|
||||
KeyVector js;
|
||||
for(const DiscreteKey& key: *this)
|
||||
js.push_back(key.first);
|
||||
for (const DiscreteKey& key : *this) js.push_back(key.first);
|
||||
return js;
|
||||
}
|
||||
|
||||
map<Key, size_t> DiscreteKeys::cardinalities() const {
|
||||
map<Key, size_t> cs;
|
||||
cs.insert(begin(), end());
|
||||
// for(const DiscreteKey& key: *this)
|
||||
// cs.insert(key);
|
||||
return cs;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteLookupDAG.cpp
|
||||
* @date Feb 14, 2011
|
||||
* @author Duy-Nguyen Ta
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
using std::pair;
|
||||
using std::vector;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************** */
|
||||
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||
void DiscreteLookupTable::print(const std::string& s,
|
||||
const KeyFormatter& formatter) const {
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
cout << s << " g( ";
|
||||
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||
cout << formatter(*it) << " ";
|
||||
}
|
||||
if (nrParents()) {
|
||||
cout << "; ";
|
||||
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||
cout << formatter(*it) << " ";
|
||||
}
|
||||
}
|
||||
cout << "):\n";
|
||||
ADT::print("", formatter);
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
double maxP = 0;
|
||||
|
||||
// Get all Possible Configurations
|
||||
const auto allPosbValues = frontalAssignments();
|
||||
|
||||
// Find the maximum
|
||||
for (const auto& frontalVals : allPosbValues) {
|
||||
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||
// Update maximum solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = frontalVals;
|
||||
}
|
||||
}
|
||||
|
||||
// set values (inPlace) to maximum
|
||||
for (Key j : frontals()) {
|
||||
(*values)[j] = mpe[j];
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
size_t mpe = 0;
|
||||
double maxP = 0;
|
||||
DiscreteValues frontals;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = value;
|
||||
}
|
||||
}
|
||||
return mpe;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
||||
const DiscreteBayesNet& bayesNet) {
|
||||
DiscreteLookupDAG dag;
|
||||
for (auto&& conditional : bayesNet) {
|
||||
if (auto lookupTable =
|
||||
boost::dynamic_pointer_cast<DiscreteLookupTable>(conditional)) {
|
||||
dag.push_back(lookupTable);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"DiscreteFactorGraph::maxProduct: Expected look up table.");
|
||||
}
|
||||
}
|
||||
return dag;
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
||||
// Argmax each node in turn in topological sort order (parents first).
|
||||
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||
lookupTable->argmaxInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
/* ************************************************************************** */
|
||||
|
||||
} // namespace gtsam
|
|
@ -0,0 +1,140 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteLookupDAG.h
|
||||
* @date January, 2022
|
||||
* @author Frank dellaert
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class DiscreteBayesNet;
|
||||
|
||||
/**
|
||||
* @brief DiscreteLookupTable table for max-product
|
||||
*
|
||||
* Inherits from discrete conditional for convenience, but is not normalized.
|
||||
* Is used in the max-product algorithm.
|
||||
*/
|
||||
class DiscreteLookupTable : public DiscreteConditional {
|
||||
public:
|
||||
using This = DiscreteLookupTable;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Discrete Lookup Table object
|
||||
*
|
||||
* @param nFrontals number of frontal variables
|
||||
* @param keys a orted list of gtsam::Keys
|
||||
* @param potentials the algebraic decision tree with lookup values
|
||||
*/
|
||||
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||
const ADT& potentials)
|
||||
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
const std::string& s = "Discrete Lookup Table: ",
|
||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/**
|
||||
* @brief return assignment for single frontal variable that maximizes value.
|
||||
* @param parentsValues Known assignments for the parents.
|
||||
* @return maximizing assignment for the frontal variable.
|
||||
*/
|
||||
size_t argmax(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/**
|
||||
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||
* @param (in/out) parentsValues Known assignments for the parents.
|
||||
*/
|
||||
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||
};
|
||||
|
||||
/** A DAG made from lookup tables, as defined above. */
|
||||
class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
||||
public:
|
||||
using Base = BayesNet<DiscreteLookupTable>;
|
||||
using This = DiscreteLookupDAG;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Construct empty DAG.
|
||||
DiscreteLookupDAG() {}
|
||||
|
||||
/// Create from BayesNet with LookupTables
|
||||
static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet);
|
||||
|
||||
/// Destructor
|
||||
virtual ~DiscreteLookupDAG() {}
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This& bn, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Add a DiscreteLookupTable */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief argmax by back-substitution, optionally given certain variables.
|
||||
*
|
||||
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||
* conditional will be optimized first *and* that the
|
||||
* DAG does not contain any conditionals for the given variables. If the DAG
|
||||
* resulted from eliminating a factor graph, this is true for the elimination
|
||||
* ordering.
|
||||
*
|
||||
* @return given assignment extended w. optimal assignment for all variables.
|
||||
*/
|
||||
DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscreteLookupDAG> : public Testable<DiscreteLookupDAG> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -111,11 +111,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
gtsam::DecisionTreeFactor* likelihood(
|
||||
const gtsam::DiscreteValues& frontalValues) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||
size_t sample(size_t value) const;
|
||||
size_t sample() const;
|
||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
|
@ -138,7 +136,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional {
|
|||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(size_t value) const;
|
||||
std::vector<double> pmf() const;
|
||||
size_t solve() const;
|
||||
size_t argmax() const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
|
@ -163,8 +161,6 @@ class DiscreteBayesNet {
|
|||
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
|
@ -217,6 +213,21 @@ class DiscreteBayesTree {
|
|||
std::map<gtsam::Key, std::vector<std::string>> names) const;
|
||||
};
|
||||
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
class DiscreteLookupDAG {
|
||||
DiscreteLookupDAG();
|
||||
void push_back(const gtsam::DiscreteLookupTable* table);
|
||||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeySet keys() const;
|
||||
const gtsam::DiscreteLookupTable* at(size_t i) const;
|
||||
void print(string s = "DiscreteLookupDAG\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DiscreteValues argmax() const;
|
||||
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
|
||||
};
|
||||
|
||||
#include <gtsam/inference/DotWriter.h>
|
||||
class DotWriter {
|
||||
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
|
||||
|
@ -260,6 +271,9 @@ class DiscreteFactorGraph {
|
|||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
|
||||
gtsam::DiscreteLookupDAG maxProduct();
|
||||
gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering);
|
||||
|
||||
gtsam::DiscreteBayesNet eliminateSequential();
|
||||
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
|
||||
std::pair<gtsam::DiscreteBayesNet, gtsam::DiscreteFactorGraph>
|
||||
|
|
|
@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) {
|
|||
DiscreteConditional expected2(Bronchitis % "11/9");
|
||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||
|
||||
// solve
|
||||
auto actualMPE = chordal->optimize();
|
||||
DiscreteValues expectedMPE;
|
||||
insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)(
|
||||
Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)(
|
||||
LungCancer.first, 0)(Bronchitis.first, 0);
|
||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||
|
||||
// add evidence, we were in Asia and we have dyspnea
|
||||
fg.add(Asia, "0 1");
|
||||
fg.add(Dyspnea, "0 1");
|
||||
|
||||
// solve again, now with evidence
|
||||
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
|
||||
auto actualMPE2 = chordal2->optimize();
|
||||
DiscreteValues expectedMPE2;
|
||||
insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)(
|
||||
Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)(
|
||||
LungCancer.first, 0)(Bronchitis.first, 1);
|
||||
EXPECT(assert_equal(expectedMPE2, actualMPE2));
|
||||
EXPECT(assert_equal(expected2, *chordal->back()));
|
||||
|
||||
// now sample from it
|
||||
DiscreteValues expectedSample;
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testDiscretePrior.cpp
|
||||
* @file testDiscreteDistribution.cpp
|
||||
* @brief unit tests for DiscreteDistribution
|
||||
* @author Frank dellaert
|
||||
* @date December 2021
|
||||
|
@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) {
|
|||
prior.sample();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteDistribution, argmax) {
|
||||
DiscreteDistribution prior(X % "2/3");
|
||||
EXPECT_LONGS_EQUAL(prior.argmax(), 1);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) {
|
|||
graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||
graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0");
|
||||
|
||||
// graph.print("Graph: ");
|
||||
DecisionTreeFactor product = graph.product();
|
||||
DecisionTreeFactor::shared_ptr sum = product.sum(1);
|
||||
// sum->print("Debug SUM: ");
|
||||
DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum));
|
||||
|
||||
// cond->print("marginal:");
|
||||
|
||||
// pair<DiscreteBayesNet::shared_ptr, DiscreteFactor::shared_ptr> result = EliminateDiscrete(graph, 1);
|
||||
// result.first->print("BayesNet: ");
|
||||
// result.second->print("New factor: ");
|
||||
//
|
||||
Ordering ordering;
|
||||
ordering += Key(0),Key(1),Key(2),Key(3);
|
||||
DiscreteEliminationTree eliminationTree(graph, ordering);
|
||||
// eliminationTree.print("Elimination tree: ");
|
||||
eliminationTree.eliminate(EliminateDiscrete);
|
||||
// solver.optimize();
|
||||
// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate();
|
||||
// Check MPE.
|
||||
auto actualMPE = graph.optimize();
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0);
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -115,8 +101,7 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteFactorGraph, test)
|
||||
{
|
||||
TEST(DiscreteFactorGraph, test) {
|
||||
// Declare keys and ordering
|
||||
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||
|
||||
|
@ -127,55 +112,52 @@ TEST( DiscreteFactorGraph, test)
|
|||
graph.add(C & B, "3 1 1 3");
|
||||
|
||||
// Test EliminateDiscrete
|
||||
// FIXME: apparently Eliminate returns a conditional rather than a net
|
||||
Ordering frontalKeys;
|
||||
frontalKeys += Key(0);
|
||||
DiscreteConditional::shared_ptr conditional;
|
||||
DecisionTreeFactor::shared_ptr newFactor;
|
||||
boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys);
|
||||
|
||||
// Check Bayes net
|
||||
// Check Conditional
|
||||
CHECK(conditional);
|
||||
DiscreteBayesNet expected;
|
||||
Signature signature((C | B, A) = "9/1 1/1 1/1 1/9");
|
||||
// cout << signature << endl;
|
||||
DiscreteConditional expectedConditional(signature);
|
||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||
expected.add(signature);
|
||||
|
||||
// Check Factor
|
||||
CHECK(newFactor);
|
||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
||||
|
||||
// add conditionals to complete expected Bayes net
|
||||
expected.add(B | A = "5/3 3/5");
|
||||
expected.add(A % "1/1");
|
||||
// GTSAM_PRINT(expected);
|
||||
|
||||
// Test elimination tree
|
||||
// Test using elimination tree
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2);
|
||||
DiscreteEliminationTree etree(graph, ordering);
|
||||
DiscreteBayesNet::shared_ptr actual;
|
||||
DiscreteFactorGraph::shared_ptr remainingGraph;
|
||||
boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete);
|
||||
EXPECT(assert_equal(expected, *actual));
|
||||
|
||||
// // Test solver
|
||||
// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate();
|
||||
// EXPECT(assert_equal(expected, *actual2));
|
||||
// Check Bayes net
|
||||
DiscreteBayesNet expectedBayesNet;
|
||||
expectedBayesNet.add(signature);
|
||||
expectedBayesNet.add(B | A = "5/3 3/5");
|
||||
expectedBayesNet.add(A % "1/1");
|
||||
EXPECT(assert_equal(expectedBayesNet, *actual));
|
||||
|
||||
// Test optimization
|
||||
DiscreteValues expectedValues;
|
||||
insert(expectedValues)(0, 0)(1, 0)(2, 0);
|
||||
auto actualValues = graph.optimize();
|
||||
EXPECT(assert_equal(expectedValues, actualValues));
|
||||
// Test eliminateSequential
|
||||
DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering);
|
||||
EXPECT(assert_equal(expectedBayesNet, *actual2));
|
||||
|
||||
// Test mpe
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 0)(1, 0)(2, 0);
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteFactorGraph, testMPE)
|
||||
{
|
||||
TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) {
|
||||
// Declare a bunch of keys
|
||||
DiscreteKey C(0, 2), A(1, 2), B(2, 2);
|
||||
|
||||
|
@ -183,19 +165,54 @@ TEST( DiscreteFactorGraph, testMPE)
|
|||
DiscreteFactorGraph graph;
|
||||
graph.add(C & A, "0.2 0.8 0.3 0.7");
|
||||
graph.add(C & B, "0.1 0.9 0.4 0.6");
|
||||
// graph.product().print();
|
||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
||||
|
||||
auto actualMPE = graph.optimize();
|
||||
// Created expected MPE
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 0)(1, 1)(2, 1);
|
||||
|
||||
DiscreteValues expectedMPE;
|
||||
insert(expectedMPE)(0, 0)(1, 1)(2, 1);
|
||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||
// Do max-product with different orderings
|
||||
for (Ordering::OrderingType orderingType :
|
||||
{Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL,
|
||||
Ordering::CUSTOM}) {
|
||||
DiscreteLookupDAG dag = graph.maxProduct(orderingType);
|
||||
auto actualMPE = dag.argmax();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
auto actualMPE2 = graph.optimize(); // all in one
|
||||
EXPECT(assert_equal(mpe, actualMPE2));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
||||
{
|
||||
TEST(DiscreteFactorGraph, marginalIsNotMPE) {
|
||||
// Declare 2 keys
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Create Bayes net such that marginal on A is bigger for 0 than 1, but the
|
||||
// MPE does not have A=0.
|
||||
DiscreteBayesNet bayesNet;
|
||||
bayesNet.add(B | A = "1/1 1/2");
|
||||
bayesNet.add(A % "10/9");
|
||||
|
||||
// The expected MPE is A=1, B=1
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 1)(1, 1);
|
||||
|
||||
// Which we verify using max-product:
|
||||
DiscreteFactorGraph graph(bayesNet);
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression
|
||||
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
// Optimize on BayesNet maximizes marginal, then the conditional marginals:
|
||||
auto notOptimal = bayesNet.optimize();
|
||||
EXPECT(graph(notOptimal) < graph(mpe));
|
||||
EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression
|
||||
#endif
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
|
||||
// The factor graph in Darwiche09book, page 244
|
||||
DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2);
|
||||
|
||||
|
@ -207,52 +224,34 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244)
|
|||
graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95");
|
||||
graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0");
|
||||
graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche)
|
||||
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1);
|
||||
EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression
|
||||
// You can check visually by printing product:
|
||||
// graph.product().print("Darwiche-product");
|
||||
// graph.product().potentials().dot("Darwiche-product");
|
||||
// DiscreteSequentialSolver(graph).eliminate()->print();
|
||||
|
||||
DiscreteValues expectedMPE;
|
||||
insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1);
|
||||
|
||||
// Use the solver machinery.
|
||||
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
|
||||
auto actualMPE = chordal->optimize();
|
||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||
// DiscreteConditional::shared_ptr root = chordal->back();
|
||||
// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9);
|
||||
|
||||
// Let us create the Bayes tree here, just for fun, because we don't use it now
|
||||
// typedef JunctionTreeOrdered<DiscreteFactorGraph> JT;
|
||||
// GenericMultifrontalSolver<DiscreteFactor, JT> solver(graph);
|
||||
// BayesTreeOrdered<DiscreteConditional>::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete);
|
||||
//// bayesTree->print("Bayes Tree");
|
||||
// EXPECT_LONGS_EQUAL(2,bayesTree->size());
|
||||
// Check MPE.
|
||||
auto actualMPE = graph.optimize();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
|
||||
// Check Bayes Net
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
||||
DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering);
|
||||
auto chordal = graph.eliminateSequential(ordering);
|
||||
EXPECT_LONGS_EQUAL(5, chordal->size());
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
auto notOptimal = chordal->optimize(); // not MPE !
|
||||
EXPECT(graph(notOptimal) < graph(mpe));
|
||||
#endif
|
||||
|
||||
// Let us create the Bayes tree here, just for fun, because we don't use it
|
||||
DiscreteBayesTree::shared_ptr bayesTree =
|
||||
graph.eliminateMultifrontal(ordering);
|
||||
// bayesTree->print("Bayes Tree");
|
||||
EXPECT_LONGS_EQUAL(2, bayesTree->size());
|
||||
|
||||
#ifdef OLD
|
||||
// Create the elimination tree manually
|
||||
VariableIndexOrdered structure(graph);
|
||||
typedef EliminationTreeOrdered<DiscreteFactor> ETree;
|
||||
ETree::shared_ptr eTree = ETree::Create(graph, structure);
|
||||
//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<");
|
||||
|
||||
// eliminate normally and check solution
|
||||
DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete);
|
||||
// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<");
|
||||
auto actualMPE = optimize(*bayesNet);
|
||||
EXPECT(assert_equal(expectedMPE, actualMPE));
|
||||
|
||||
// Approximate and check solution
|
||||
// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate();
|
||||
// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<");
|
||||
// EXPECT(assert_equal(expectedMPE, *actualMPE));
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef OLD
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* testDiscreteLookupDAG.cpp
|
||||
*
|
||||
* @date January, 2022
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
|
||||
#include <boost/assign/list_inserter.hpp>
|
||||
#include <boost/assign/std/map.hpp>
|
||||
|
||||
using namespace gtsam;
|
||||
using namespace boost::assign;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteLookupDAG, argmax) {
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
|
||||
// Declare 2 keys
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
|
||||
DiscreteLookupDAG dag;
|
||||
|
||||
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
|
||||
dag.add(1, DiscreteKeys{B, A}, adtB);
|
||||
|
||||
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
|
||||
dag.add(1, DiscreteKeys{A}, adtA);
|
||||
|
||||
// The expected MPE is A=1, B=1
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 1)(1, 1);
|
||||
|
||||
// check:
|
||||
auto actualMPE = dag.argmax();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
|
@ -25,15 +25,12 @@
|
|||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* TODO: Update comments. The following comments are out of date!!!
|
||||
*
|
||||
* Base class for conditional densities, templated on KEY type. This class
|
||||
* provides storage for the keys involved in a conditional, and iterators and
|
||||
* Base class for conditional densities. This class iterators and
|
||||
* access to the frontal and separator keys.
|
||||
*
|
||||
* Derived classes *must* redefine the Factor and shared_ptr typedefs to refer
|
||||
* to the associated factor type and shared_ptr type of the derived class. See
|
||||
* IndexConditional and GaussianConditional for examples.
|
||||
* SymbolicConditional and GaussianConditional for examples.
|
||||
* \nosubgrouping
|
||||
*/
|
||||
template<class FACTOR, class DERIVEDCONDITIONAL>
|
||||
|
|
|
@ -14,18 +14,6 @@ using namespace std;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/// Find the best total assignment - can be expensive
|
||||
DiscreteValues CSP::optimalAssignment() const {
|
||||
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential();
|
||||
return chordal->optimize();
|
||||
}
|
||||
|
||||
/// Find the best total assignment - can be expensive
|
||||
DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const {
|
||||
DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering);
|
||||
return chordal->optimize();
|
||||
}
|
||||
|
||||
bool CSP::runArcConsistency(const VariableIndex& index,
|
||||
Domains* domains) const {
|
||||
bool changed = false;
|
||||
|
|
|
@ -43,12 +43,6 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph {
|
|||
// return result;
|
||||
// }
|
||||
|
||||
/// Find the best total assignment - can be expensive.
|
||||
DiscreteValues optimalAssignment() const;
|
||||
|
||||
/// Find the best total assignment, with given ordering - can be expensive.
|
||||
DiscreteValues optimalAssignment(const Ordering& ordering) const;
|
||||
|
||||
// /*
|
||||
// * Perform loopy belief propagation
|
||||
// * True belief propagation would check for each value in domain
|
||||
|
|
|
@ -255,23 +255,6 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const {
|
|||
return chordal;
|
||||
}
|
||||
|
||||
/** Find the best total assignment - can be expensive */
|
||||
DiscreteValues Scheduler::optimalAssignment() const {
|
||||
DiscreteBayesNet::shared_ptr chordal = eliminate();
|
||||
|
||||
if (ISDEBUG("Scheduler::optimalAssignment")) {
|
||||
DiscreteBayesNet::const_iterator it = chordal->end() - 1;
|
||||
const Student& student = students_.front();
|
||||
cout << endl;
|
||||
(*it)->print(student.name_);
|
||||
}
|
||||
|
||||
gttic(my_optimize);
|
||||
DiscreteValues mpe = chordal->optimize();
|
||||
gttoc(my_optimize);
|
||||
return mpe;
|
||||
}
|
||||
|
||||
/** find the assignment of students to slots with most possible committees */
|
||||
DiscreteValues Scheduler::bestSchedule() const {
|
||||
DiscreteValues best;
|
||||
|
|
|
@ -147,9 +147,6 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP {
|
|||
/** Eliminate, return a Bayes net */
|
||||
DiscreteBayesNet::shared_ptr eliminate() const;
|
||||
|
||||
/** Find the best total assignment - can be expensive */
|
||||
DiscreteValues optimalAssignment() const;
|
||||
|
||||
/** find the assignment of students to slots with most possible committees */
|
||||
DiscreteValues bestSchedule() const;
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ void runLargeExample() {
|
|||
// SETDEBUG("timing-verbose", true);
|
||||
SETDEBUG("DiscreteConditional::DiscreteConditional", true);
|
||||
gttic(large);
|
||||
auto MPE = scheduler.optimalAssignment();
|
||||
auto MPE = scheduler.optimize();
|
||||
gttoc(large);
|
||||
tictoc_finishedIteration();
|
||||
tictoc_print();
|
||||
|
@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) {
|
|||
root->print(""/*scheduler.studentName(s)*/);
|
||||
|
||||
// solve root node only
|
||||
DiscreteValues values;
|
||||
size_t bestSlot = root->solve(values);
|
||||
size_t bestSlot = root->argmax();
|
||||
|
||||
// get corresponding count
|
||||
DiscreteKey dkey = scheduler.studentKey(6 - s);
|
||||
DiscreteValues values;
|
||||
values[dkey.first] = bestSlot;
|
||||
size_t count = (*root)(values);
|
||||
|
||||
|
@ -319,11 +319,11 @@ void accomodateStudent() {
|
|||
// GTSAM_PRINT(*chordal);
|
||||
|
||||
// solve root node only
|
||||
DiscreteValues values;
|
||||
size_t bestSlot = root->solve(values);
|
||||
size_t bestSlot = root->argmax();
|
||||
|
||||
// get corresponding count
|
||||
DiscreteKey dkey = scheduler.studentKey(0);
|
||||
DiscreteValues values;
|
||||
values[dkey.first] = bestSlot;
|
||||
size_t count = (*root)(values);
|
||||
cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0)
|
||||
|
|
|
@ -143,7 +143,7 @@ void runLargeExample() {
|
|||
}
|
||||
#else
|
||||
gttic(large);
|
||||
auto MPE = scheduler.optimalAssignment();
|
||||
auto MPE = scheduler.optimize();
|
||||
gttoc(large);
|
||||
tictoc_finishedIteration();
|
||||
tictoc_print();
|
||||
|
@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) {
|
|||
root->print(""/*scheduler.studentName(s)*/);
|
||||
|
||||
// solve root node only
|
||||
DiscreteValues values;
|
||||
size_t bestSlot = root->solve(values);
|
||||
size_t bestSlot = root->argmax();
|
||||
|
||||
// get corresponding count
|
||||
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
||||
DiscreteValues values;
|
||||
values[dkey.first] = bestSlot;
|
||||
size_t count = (*root)(values);
|
||||
|
||||
|
|
|
@ -167,7 +167,7 @@ void runLargeExample() {
|
|||
}
|
||||
#else
|
||||
gttic(large);
|
||||
auto MPE = scheduler.optimalAssignment();
|
||||
auto MPE = scheduler.optimize();
|
||||
gttoc(large);
|
||||
tictoc_finishedIteration();
|
||||
tictoc_print();
|
||||
|
@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) {
|
|||
root->print(""/*scheduler.studentName(s)*/);
|
||||
|
||||
// solve root node only
|
||||
DiscreteValues values;
|
||||
size_t bestSlot = root->solve(values);
|
||||
size_t bestSlot = root->argmax();
|
||||
|
||||
// get corresponding count
|
||||
DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
|
||||
DiscreteValues values;
|
||||
values[dkey.first] = bestSlot;
|
||||
double count = (*root)(values);
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ TEST(CSP, allInOne) {
|
|||
EXPECT(assert_equal(expectedProduct, product));
|
||||
|
||||
// Solve
|
||||
auto mpe = csp.optimalAssignment();
|
||||
auto mpe = csp.optimize();
|
||||
DiscreteValues expected;
|
||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
|
||||
EXPECT(assert_equal(expected, mpe));
|
||||
|
@ -172,22 +172,18 @@ TEST(CSP, WesternUS) {
|
|||
csp.addAllDiff(WY, CO);
|
||||
csp.addAllDiff(CO, NM);
|
||||
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0);
|
||||
|
||||
// Create ordering according to example in ND-CSP.lyx
|
||||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
|
||||
Key(8), Key(9), Key(10);
|
||||
// Solve using that ordering:
|
||||
auto mpe = csp.optimalAssignment(ordering);
|
||||
// GTSAM_PRINT(mpe);
|
||||
DiscreteValues expected;
|
||||
insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
|
||||
MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
|
||||
UT.first, 1)(AZ.first, 0);
|
||||
|
||||
// TODO: Fix me! mpe result seems to be right. (See the printing)
|
||||
// It has the same prob as the expected solution.
|
||||
// Is mpe another solution, or the expected solution is unique???
|
||||
EXPECT(assert_equal(expected, mpe));
|
||||
// Solve using that ordering:
|
||||
auto actualMPE = csp.optimize(ordering);
|
||||
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
|
||||
|
||||
// Write out the dual graph for hmetis
|
||||
|
@ -227,7 +223,7 @@ TEST(CSP, ArcConsistency) {
|
|||
EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
|
||||
|
||||
// Solve
|
||||
auto mpe = csp.optimalAssignment();
|
||||
auto mpe = csp.optimize();
|
||||
DiscreteValues expected;
|
||||
insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
|
||||
EXPECT(assert_equal(expected, mpe));
|
||||
|
|
|
@ -122,7 +122,7 @@ TEST(schedulingExample, test) {
|
|||
|
||||
// Do exact inference
|
||||
gttic(small);
|
||||
auto MPE = s.optimalAssignment();
|
||||
auto MPE = s.optimize();
|
||||
gttoc(small);
|
||||
|
||||
// print MPE, commented out as unit tests don't print
|
||||
|
|
|
@ -100,7 +100,7 @@ class Sudoku : public CSP {
|
|||
|
||||
/// solve and print solution
|
||||
void printSolution() const {
|
||||
auto MPE = optimalAssignment();
|
||||
auto MPE = optimize();
|
||||
printAssignment(MPE);
|
||||
}
|
||||
|
||||
|
@ -126,7 +126,7 @@ TEST(Sudoku, small) {
|
|||
0, 1, 0, 0);
|
||||
|
||||
// optimize and check
|
||||
auto solution = csp.optimalAssignment();
|
||||
auto solution = csp.optimize();
|
||||
DiscreteValues expected;
|
||||
insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
|
||||
csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
|
||||
|
@ -148,7 +148,7 @@ TEST(Sudoku, small) {
|
|||
EXPECT_LONGS_EQUAL(16, new_csp.size());
|
||||
|
||||
// Check that solution
|
||||
auto new_solution = new_csp.optimalAssignment();
|
||||
auto new_solution = new_csp.optimize();
|
||||
// csp.printAssignment(new_solution);
|
||||
EXPECT(assert_equal(expected, new_solution));
|
||||
}
|
||||
|
@ -250,7 +250,7 @@ TEST(Sudoku, AJC_3star_Feb8_2012) {
|
|||
EXPECT_LONGS_EQUAL(81, new_csp.size());
|
||||
|
||||
// Check that solution
|
||||
auto solution = new_csp.optimalAssignment();
|
||||
auto solution = new_csp.optimize();
|
||||
// csp.printAssignment(solution);
|
||||
EXPECT_LONGS_EQUAL(6, solution.at(key99));
|
||||
}
|
||||
|
|
|
@ -79,7 +79,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
self.gtsamAssertEquals(chordal.at(7), expected2)
|
||||
|
||||
# solve
|
||||
actualMPE = chordal.optimize()
|
||||
actualMPE = fg.optimize()
|
||||
expectedMPE = DiscreteValues()
|
||||
for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
|
||||
expectedMPE[key[0]] = 0
|
||||
|
@ -94,8 +94,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
fg.add(Dyspnea, "0 1")
|
||||
|
||||
# solve again, now with evidence
|
||||
chordal2 = fg.eliminateSequential(ordering)
|
||||
actualMPE2 = chordal2.optimize()
|
||||
actualMPE2 = fg.optimize()
|
||||
expectedMPE2 = DiscreteValues()
|
||||
for key in [XRay, Tuberculosis, Either, LungCancer]:
|
||||
expectedMPE2[key[0]] = 0
|
||||
|
@ -105,6 +104,7 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
list(expectedMPE2.items()))
|
||||
|
||||
# now sample from it
|
||||
chordal2 = fg.eliminateSequential(ordering)
|
||||
actualSample = chordal2.sample()
|
||||
self.assertEqual(len(actualSample), 8)
|
||||
|
||||
|
@ -122,10 +122,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
for key in [Asia, Smoking]:
|
||||
given[key[0]] = 0
|
||||
|
||||
# Now optimize fragment:
|
||||
actual = fragment.optimize(given)
|
||||
self.assertEqual(len(actual), 5)
|
||||
|
||||
# Now sample from fragment:
|
||||
actual = fragment.sample(given)
|
||||
self.assertEqual(len(actual), 5)
|
||||
|
|
|
@ -20,7 +20,7 @@ from gtsam.utils.test_case import GtsamTestCase
|
|||
X = 0, 2
|
||||
|
||||
|
||||
class TestDiscretePrior(GtsamTestCase):
|
||||
class TestDiscreteDistribution(GtsamTestCase):
|
||||
"""Tests for Discrete Priors."""
|
||||
|
||||
def test_constructor(self):
|
||||
|
|
Loading…
Reference in New Issue