made internal protected choose to avoid copy/paste in Lookup
parent
2f49612b8c
commit
e713897235
|
@ -16,26 +16,25 @@
|
|||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/debug.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <boost/make_shared.hpp>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using std::pair;
|
||||
using std::stringstream;
|
||||
using std::vector;
|
||||
using std::pair;
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
|
@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
|
|||
cout << endl;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
|
@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
|
|||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||
const DiscreteValues& given,
|
||||
bool forceComplete = true) {
|
||||
DiscreteConditional::ADT DiscreteConditional::choose(
|
||||
const DiscreteValues& given, bool forceComplete) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
DiscreteConditional::ADT adt(conditional);
|
||||
DiscreteConditional::ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : conditional.parents()) {
|
||||
for (Key j : parents()) {
|
||||
try {
|
||||
value = given.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
|
@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
|||
if (forceComplete) {
|
||||
given.print("parentsValues: ");
|
||||
throw runtime_error(
|
||||
"DiscreteConditional::Choose: parent value missing");
|
||||
"DiscreteConditional::choose: parent value missing");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
|||
/* ************************************************************************** */
|
||||
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||
const DiscreteValues& given) const {
|
||||
ADT adt = Choose(*this, given, false); // P(F|S=given)
|
||||
ADT adt = choose(given, false); // P(F|S=given)
|
||||
|
||||
// Collect all keys not in given.
|
||||
DiscreteKeys dKeys;
|
||||
|
@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ****************************************************************************/
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
size_t parent_value) const {
|
||||
if (nrFrontals() != 1)
|
||||
|
@ -240,7 +238,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
/* ************************************************************************** */
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
|
@ -267,25 +265,24 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
|||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO, only works for one key now, seems horribly slow this way
|
||||
size_t mpe = 0;
|
||||
DiscreteValues frontals;
|
||||
size_t max = 0;
|
||||
double maxP = 0;
|
||||
DiscreteValues frontals;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
frontals[j] = value;
|
||||
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
// Update MPE solution if better
|
||||
// Update solution if better
|
||||
if (pValueS > maxP) {
|
||||
maxP = pValueS;
|
||||
mpe = value;
|
||||
max = value;
|
||||
}
|
||||
}
|
||||
return mpe;
|
||||
return max;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -302,7 +299,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|||
static mt19937 rng(2); // random number generator
|
||||
|
||||
// Get the correct conditional density
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
if (nrFrontals() != 1) {
|
||||
|
@ -325,7 +322,8 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
|
|||
return distribution(rng);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ********************************************************************************
|
||||
*/
|
||||
size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||
if (nrParents() != 1)
|
||||
throw std::invalid_argument(
|
||||
|
@ -336,7 +334,8 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
|
|||
return sample(values);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ********************************************************************************
|
||||
*/
|
||||
size_t DiscreteConditional::sample() const {
|
||||
if (nrParents() != 0)
|
||||
throw std::invalid_argument(
|
||||
|
|
|
@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||
|
||||
/**
|
||||
/**
|
||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
*/
|
||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||
const DecisionTreeFactor& marginal);
|
||||
|
||||
/**
|
||||
/**
|
||||
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||
|
@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
return ADT::operator()(values);
|
||||
}
|
||||
|
||||
/**
|
||||
/**
|
||||
* @brief restrict to given *parent* values.
|
||||
*
|
||||
*
|
||||
* Note: does not need be complete set. Examples:
|
||||
*
|
||||
*
|
||||
* P(C|D,E) + . -> P(C|D,E)
|
||||
* P(C|D,E) + E -> P(C|D)
|
||||
* P(C|D,E) + D -> P(C|E)
|
||||
* P(C|D,E) + D,E -> P(C)
|
||||
* P(C|D,E) + C -> error!
|
||||
*
|
||||
*
|
||||
* @return a shared_ptr to a new DiscreteConditional
|
||||
*/
|
||||
shared_ptr choose(const DiscreteValues& given) const;
|
||||
|
@ -226,6 +226,11 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
|
||||
/// @}
|
||||
#endif
|
||||
|
||||
protected:
|
||||
/// Internal version of choose
|
||||
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||
bool forceComplete) const;
|
||||
};
|
||||
// DiscreteConditional
|
||||
|
||||
|
|
|
@ -49,42 +49,9 @@ void DiscreteLookupTable::print(const std::string& s,
|
|||
cout << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||
vector<DiscreteValues> DiscreteLookupTable::frontalAssignments() const {
|
||||
vector<pair<Key, size_t>> pairs;
|
||||
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
|
||||
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
return DiscreteValues::CartesianProduct(rpairs);
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||
static DiscreteLookupTable::ADT Choose(const DiscreteLookupTable& conditional,
|
||||
const DiscreteValues& given,
|
||||
bool forceComplete = true) {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
DiscreteLookupTable::ADT adt(conditional);
|
||||
size_t value;
|
||||
for (Key j : conditional.parents()) {
|
||||
try {
|
||||
value = given.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (std::out_of_range&) {
|
||||
if (forceComplete) {
|
||||
given.print("parentsValues: ");
|
||||
throw std::runtime_error(
|
||||
"DiscreteLookupTable::Choose: parent value missing");
|
||||
}
|
||||
}
|
||||
}
|
||||
return adt;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
DiscreteValues mpe;
|
||||
|
@ -111,13 +78,13 @@ void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
|
|||
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
size_t mpe = 0;
|
||||
DiscreteValues frontals;
|
||||
double maxP = 0;
|
||||
DiscreteValues frontals;
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
for (size_t value = 0; value < cardinality(j); value++) {
|
||||
|
|
|
@ -68,9 +68,6 @@ class DiscreteLookupTable : public DiscreteConditional {
|
|||
* @param (in/out) parentsValues Known assignments for the parents.
|
||||
*/
|
||||
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||
|
||||
/// Return all assignments for frontal variables.
|
||||
std::vector<DiscreteValues> frontalAssignments() const;
|
||||
};
|
||||
|
||||
/** A DAG made from lookup tables, as defined above. */
|
||||
|
|
Loading…
Reference in New Issue