made internal protected choose to avoid copy/paste in Lookup

release/4.3a0
Frank Dellaert 2022-01-21 14:26:35 -05:00
parent 2f49612b8c
commit e713897235
4 changed files with 41 additions and 73 deletions

View File

@ -16,26 +16,25 @@
* @author Frank Dellaert
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <boost/make_shared.hpp>
#include <algorithm>
#include <boost/make_shared.hpp>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
#include <set>
#include <vector>
using namespace std;
using std::pair;
using std::stringstream;
using std::vector;
using std::pair;
namespace gtsam {
// Instantiate base class
@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
cout << endl;
}
/* ******************************************************************************** */
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
}
/* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
const DiscreteValues& given,
bool forceComplete = true) {
DiscreteConditional::ADT DiscreteConditional::choose(
const DiscreteValues& given, bool forceComplete) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional);
DiscreteConditional::ADT adt(*this);
size_t value;
for (Key j : conditional.parents()) {
for (Key j : parents()) {
try {
value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
if (forceComplete) {
given.print("parentsValues: ");
throw runtime_error(
"DiscreteConditional::Choose: parent value missing");
"DiscreteConditional::choose: parent value missing");
}
}
}
@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
/* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const {
ADT adt = Choose(*this, given, false); // P(F|S=given)
ADT adt = choose(given, false); // P(F|S=given)
// Collect all keys not in given.
DiscreteKeys dKeys;
@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}
/* ******************************************************************************** */
/* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const {
if (nrFrontals() != 1)
@ -240,7 +238,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
/* ************************************************************************** */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
@ -267,25 +265,24 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
/* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
size_t mpe = 0;
DiscreteValues frontals;
size_t max = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
// Update solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
max = value;
}
}
return mpe;
return max;
}
#endif
@ -302,7 +299,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator
// Get the correct conditional density
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) {
@ -325,7 +322,8 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng);
}
/* ******************************************************************************** */
/* ********************************************************************************
*/
size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1)
throw std::invalid_argument(
@ -336,7 +334,8 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(values);
}
/* ******************************************************************************** */
/* ********************************************************************************
*/
size_t DiscreteConditional::sample() const {
if (nrParents() != 0)
throw std::invalid_argument(

View File

@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {}
/**
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
*/
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
/**
/**
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
* Makes sure the keys are ordered as given. Does not check orderedKeys.
@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional
return ADT::operator()(values);
}
/**
/**
* @brief restrict to given *parent* values.
*
*
* Note: does not need be complete set. Examples:
*
*
* P(C|D,E) + . -> P(C|D,E)
* P(C|D,E) + E -> P(C|D)
* P(C|D,E) + D -> P(C|E)
* P(C|D,E) + D,E -> P(C)
* P(C|D,E) + C -> error!
*
*
* @return a shared_ptr to a new DiscreteConditional
*/
shared_ptr choose(const DiscreteValues& given) const;
@ -226,6 +226,11 @@ class GTSAM_EXPORT DiscreteConditional
void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const;
/// @}
#endif
protected:
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
};
// DiscreteConditional

View File

@ -49,42 +49,9 @@ void DiscreteLookupTable::print(const std::string& s,
cout << endl;
}
/* ************************************************************************* */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
vector<DiscreteValues> DiscreteLookupTable::frontalAssignments() const {
vector<pair<Key, size_t>> pairs;
for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key));
vector<pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
return DiscreteValues::CartesianProduct(rpairs);
}
/* ************************************************************************** */
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
static DiscreteLookupTable::ADT Choose(const DiscreteLookupTable& conditional,
const DiscreteValues& given,
bool forceComplete = true) {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteLookupTable::ADT adt(conditional);
size_t value;
for (Key j : conditional.parents()) {
try {
value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) {
if (forceComplete) {
given.print("parentsValues: ");
throw std::runtime_error(
"DiscreteLookupTable::Choose: parent value missing");
}
}
}
return adt;
}
/* ************************************************************************** */
void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
@ -111,13 +78,13 @@ void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const {
/* ************************************************************************** */
size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Then, find the max over all remaining
// TODO(Duy): only works for one key now, seems horribly slow this way
size_t mpe = 0;
DiscreteValues frontals;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {

View File

@ -68,9 +68,6 @@ class DiscreteLookupTable : public DiscreteConditional {
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
/// Return all assignments for frontal variables.
std::vector<DiscreteValues> frontalAssignments() const;
};
/** A DAG made from lookup tables, as defined above. */