made internal protected choose to avoid copy/paste in Lookup
parent
2f49612b8c
commit
e713897235
|
@ -16,26 +16,25 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/debug.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/base/debug.h>
|
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <set>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <set>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
using std::pair;
|
||||||
using std::stringstream;
|
using std::stringstream;
|
||||||
using std::vector;
|
using std::vector;
|
||||||
using std::pair;
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
|
@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
double tol) const {
|
double tol) const {
|
||||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||||
|
@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
DiscreteConditional::ADT DiscreteConditional::choose(
|
||||||
const DiscreteValues& given,
|
const DiscreteValues& given, bool forceComplete) const {
|
||||||
bool forceComplete = true) {
|
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
// branches based on the value of the parent variables.
|
// branches based on the value of the parent variables.
|
||||||
DiscreteConditional::ADT adt(conditional);
|
DiscreteConditional::ADT adt(*this);
|
||||||
size_t value;
|
size_t value;
|
||||||
for (Key j : conditional.parents()) {
|
for (Key j : parents()) {
|
||||||
try {
|
try {
|
||||||
value = given.at(j);
|
value = given.at(j);
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
|
@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||||
if (forceComplete) {
|
if (forceComplete) {
|
||||||
given.print("parentsValues: ");
|
given.print("parentsValues: ");
|
||||||
throw runtime_error(
|
throw runtime_error(
|
||||||
"DiscreteConditional::Choose: parent value missing");
|
"DiscreteConditional::choose: parent value missing");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||||
const DiscreteValues& given) const {
|
const DiscreteValues& given) const {
|
||||||
ADT adt = Choose(*this, given, false); // P(F|S=given)
|
ADT adt = choose(given, false); // P(F|S=given)
|
||||||
|
|
||||||
// Collect all keys not in given.
|
// Collect all keys not in given.
|
||||||
DiscreteKeys dKeys;
|
DiscreteKeys dKeys;
|
||||||
|
@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ****************************************************************************/
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
size_t parent_value) const {
|
size_t parent_value) const {
|
||||||
if (nrFrontals() != 1)
|
if (nrFrontals() != 1)
|
||||||
|
@ -240,7 +238,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
DiscreteValues mpe;
|
DiscreteValues mpe;
|
||||||
|
@ -267,25 +265,24 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Then, find the max over all remaining
|
// Then, find the max over all remaining
|
||||||
// TODO, only works for one key now, seems horribly slow this way
|
size_t max = 0;
|
||||||
size_t mpe = 0;
|
|
||||||
DiscreteValues frontals;
|
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
Key j = (firstFrontalKey());
|
Key j = (firstFrontalKey());
|
||||||
for (size_t value = 0; value < cardinality(j); value++) {
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
frontals[j] = value;
|
frontals[j] = value;
|
||||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
// Update MPE solution if better
|
// Update solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
mpe = value;
|
max = value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return mpe;
|
return max;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -302,7 +299,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
static mt19937 rng(2); // random number generator
|
static mt19937 rng(2); // random number generator
|
||||||
|
|
||||||
// Get the correct conditional density
|
// Get the correct conditional density
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
if (nrFrontals() != 1) {
|
if (nrFrontals() != 1) {
|
||||||
|
@ -325,7 +322,8 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
||||||
return distribution(rng);
|
return distribution(rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ********************************************************************************
|
||||||
|
*/
|
||||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
if (nrParents() != 1)
|
if (nrParents() != 1)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
@ -336,7 +334,8 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
return sample(values);
|
return sample(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ********************************************************************************
|
||||||
|
*/
|
||||||
size_t DiscreteConditional::sample() const {
|
size_t DiscreteConditional::sample() const {
|
||||||
if (nrParents() != 0)
|
if (nrParents() != 0)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
|
|
|
@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
*/
|
*/
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal);
|
const DecisionTreeFactor& marginal);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||||
|
@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief restrict to given *parent* values.
|
* @brief restrict to given *parent* values.
|
||||||
*
|
*
|
||||||
* Note: does not need be complete set. Examples:
|
* Note: does not need be complete set. Examples:
|
||||||
*
|
*
|
||||||
* P(C|D,E) + . -> P(C|D,E)
|
* P(C|D,E) + . -> P(C|D,E)
|
||||||
* P(C|D,E) + E -> P(C|D)
|
* P(C|D,E) + E -> P(C|D)
|
||||||
* P(C|D,E) + D -> P(C|E)
|
* P(C|D,E) + D -> P(C|E)
|
||||||
* P(C|D,E) + D,E -> P(C)
|
* P(C|D,E) + D,E -> P(C)
|
||||||
* P(C|D,E) + C -> error!
|
* P(C|D,E) + C -> error!
|
||||||
*
|
*
|
||||||
* @return a shared_ptr to a new DiscreteConditional
|
* @return a shared_ptr to a new DiscreteConditional
|
||||||
*/
|
*/
|
||||||
shared_ptr choose(const DiscreteValues& given) const;
|
shared_ptr choose(const DiscreteValues& given) const;
|
||||||
|
@ -226,6 +226,11 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
|
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
|
||||||
/// @}
|
/// @}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// Internal version of choose
|
||||||
|
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||||
|
bool forceComplete) const;
|
||||||
};
|
};
|
||||||
// DiscreteConditional
|
// DiscreteConditional
|
||||||
|
|
||||||
|
|
|
@ -49,42 +49,9 @@ void DiscreteLookupTable::print(const std::string& s,
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
|
||||||
vector<DiscreteValues> DiscreteLookupTable::frontalAssignments() const {
|
|
||||||
vector<pair<Key, size_t>> pairs;
|
|
||||||
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
|
|
||||||
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
|
||||||
return DiscreteValues::CartesianProduct(rpairs);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
|
||||||
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
|
||||||
static DiscreteLookupTable::ADT Choose(const DiscreteLookupTable& conditional,
|
|
||||||
const DiscreteValues& given,
|
|
||||||
bool forceComplete = true) {
|
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
|
||||||
// branches based on the value of the parent variables.
|
|
||||||
DiscreteLookupTable::ADT adt(conditional);
|
|
||||||
size_t value;
|
|
||||||
for (Key j : conditional.parents()) {
|
|
||||||
try {
|
|
||||||
value = given.at(j);
|
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
|
||||||
} catch (std::out_of_range&) {
|
|
||||||
if (forceComplete) {
|
|
||||||
given.print("parentsValues: ");
|
|
||||||
throw std::runtime_error(
|
|
||||||
"DiscreteLookupTable::Choose: parent value missing");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return adt;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
DiscreteValues mpe;
|
DiscreteValues mpe;
|
||||||
|
@ -111,13 +78,13 @@ void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Then, find the max over all remaining
|
// Then, find the max over all remaining
|
||||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
size_t mpe = 0;
|
size_t mpe = 0;
|
||||||
DiscreteValues frontals;
|
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
DiscreteValues frontals;
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
Key j = (firstFrontalKey());
|
Key j = (firstFrontalKey());
|
||||||
for (size_t value = 0; value < cardinality(j); value++) {
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
|
|
@ -68,9 +68,6 @@ class DiscreteLookupTable : public DiscreteConditional {
|
||||||
* @param (in/out) parentsValues Known assignments for the parents.
|
* @param (in/out) parentsValues Known assignments for the parents.
|
||||||
*/
|
*/
|
||||||
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
|
||||||
/// Return all assignments for frontal variables.
|
|
||||||
std::vector<DiscreteValues> frontalAssignments() const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/** A DAG made from lookup tables, as defined above. */
|
/** A DAG made from lookup tables, as defined above. */
|
||||||
|
|
Loading…
Reference in New Issue