made internal protected choose to avoid copy/paste in Lookup

release/4.3a0
Frank Dellaert 2022-01-21 14:26:35 -05:00
parent 2f49612b8c
commit e713897235
4 changed files with 41 additions and 73 deletions

View File

@ -16,26 +16,25 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <boost/make_shared.hpp>
#include <algorithm> #include <algorithm>
#include <boost/make_shared.hpp>
#include <random> #include <random>
#include <set>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <vector>
#include <utility> #include <utility>
#include <set> #include <vector>
using namespace std; using namespace std;
using std::pair;
using std::stringstream; using std::stringstream;
using std::vector; using std::vector;
using std::pair;
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
cout << endl; cout << endl;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other, bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) { if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
} }
/* ************************************************************************** */ /* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, DiscreteConditional::ADT DiscreteConditional::choose(
const DiscreteValues& given, const DiscreteValues& given, bool forceComplete) const {
bool forceComplete = true) {
// Get the big decision tree with all the levels, and then go down the // Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables. // branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional); DiscreteConditional::ADT adt(*this);
size_t value; size_t value;
for (Key j : conditional.parents()) { for (Key j : parents()) {
try { try {
value = given.at(j); value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller. adt = adt.choose(j, value); // ADT keeps getting smaller.
@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
if (forceComplete) { if (forceComplete) {
given.print("parentsValues: "); given.print("parentsValues: ");
throw runtime_error( throw runtime_error(
"DiscreteConditional::Choose: parent value missing"); "DiscreteConditional::choose: parent value missing");
} }
} }
} }
@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose( DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const { const DiscreteValues& given) const {
ADT adt = Choose(*this, given, false); // P(F|S=given) ADT adt = choose(given, false); // P(F|S=given)
// Collect all keys not in given. // Collect all keys not in given.
DiscreteKeys dKeys; DiscreteKeys dKeys;
@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt); return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
} }
/* ******************************************************************************** */ /* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const { size_t parent_value) const {
if (nrFrontals() != 1) if (nrFrontals() != 1)
@ -240,7 +238,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
/* ************************************************************************** */ /* ************************************************************************** */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void DiscreteConditional::solveInPlace(DiscreteValues* values) const { void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize // Initialize
DiscreteValues mpe; DiscreteValues mpe;
@ -267,25 +265,24 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining // Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way size_t max = 0;
size_t mpe = 0;
DiscreteValues frontals;
double maxP = 0; double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
Key j = (firstFrontalKey()); Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) { for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value; frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues) double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better // Update solution if better
if (pValueS > maxP) { if (pValueS > maxP) {
maxP = pValueS; maxP = pValueS;
mpe = value; max = value;
} }
} }
return mpe; return max;
} }
#endif #endif
@ -302,7 +299,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator static mt19937 rng(2); // random number generator
// Get the correct conditional density // Get the correct conditional density
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way // TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) { if (nrFrontals() != 1) {
@ -325,7 +322,8 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng); return distribution(rng);
} }
/* ******************************************************************************** */ /* ********************************************************************************
*/
size_t DiscreteConditional::sample(size_t parent_value) const { size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1) if (nrParents() != 1)
throw std::invalid_argument( throw std::invalid_argument(
@ -336,7 +334,8 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(values); return sample(values);
} }
/* ******************************************************************************** */ /* ********************************************************************************
*/
size_t DiscreteConditional::sample() const { size_t DiscreteConditional::sample() const {
if (nrParents() != 0) if (nrParents() != 0)
throw std::invalid_argument( throw std::invalid_argument(

View File

@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional(const DiscreteKey& key, const std::string& spec) DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {} : DiscreteConditional(Signature(key, {}, spec)) {}
/** /**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/ */
DiscreteConditional(const DecisionTreeFactor& joint, DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal); const DecisionTreeFactor& marginal);
/** /**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys. * Makes sure the keys are ordered as given. Does not check orderedKeys.
@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values); return ADT::operator()(values);
} }
/** /**
* @brief restrict to given *parent* values. * @brief restrict to given *parent* values.
* *
* Note: does not need be complete set. Examples: * Note: does not need be complete set. Examples:
* *
* P(C|D,E) + . -> P(C|D,E) * P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D) * P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E) * P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C) * P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error! * P(C|D,E) + C -> error!
* *
* @return a shared_ptr to a new DiscreteConditional * @return a shared_ptr to a new DiscreteConditional
*/ */
shared_ptr choose(const DiscreteValues& given) const; shared_ptr choose(const DiscreteValues& given) const;
@ -226,6 +226,11 @@ class GTSAM_EXPORT DiscreteConditional
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const; void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
/// @} /// @}
#endif #endif
protected:
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
}; };
// DiscreteConditional // DiscreteConditional

View File

@ -49,42 +49,9 @@ void DiscreteLookupTable::print(const std::string& s,
cout << endl; cout << endl;
} }
/* ************************************************************************* */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
vector<DiscreteValues> DiscreteLookupTable::frontalAssignments() const {
vector<pair<Key, size_t>> pairs;
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
return DiscreteValues::CartesianProduct(rpairs);
}
/* ************************************************************************** */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
static DiscreteLookupTable::ADT Choose(const DiscreteLookupTable& conditional,
const DiscreteValues& given,
bool forceComplete = true) {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteLookupTable::ADT adt(conditional);
size_t value;
for (Key j : conditional.parents()) {
try {
value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) {
if (forceComplete) {
given.print("parentsValues: ");
throw std::runtime_error(
"DiscreteLookupTable::Choose: parent value missing");
}
}
}
return adt;
}
/* ************************************************************************** */ /* ************************************************************************** */
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize // Initialize
DiscreteValues mpe; DiscreteValues mpe;
@ -111,13 +78,13 @@ void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining // Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way // TODO(Duy): only works for one key now, seems horribly slow this way
size_t mpe = 0; size_t mpe = 0;
DiscreteValues frontals;
double maxP = 0; double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1); assert(nrFrontals() == 1);
Key j = (firstFrontalKey()); Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) { for (size_t value = 0; value < cardinality(j); value++) {

View File

@ -68,9 +68,6 @@ class DiscreteLookupTable : public DiscreteConditional {
* @param (in/out) parentsValues Known assignments for the parents. * @param (in/out) parentsValues Known assignments for the parents.
*/ */
void argmaxInPlace(DiscreteValues* parentsValues) const; void argmaxInPlace(DiscreteValues* parentsValues) const;
/// Return all assignments for frontal variables.
std::vector<DiscreteValues> frontalAssignments() const;
}; };
/** A DAG made from lookup tables, as defined above. */ /** A DAG made from lookup tables, as defined above. */