New, more powerful choose, yields a Conditional now
parent
2413fcb91f
commit
4a10ea89a5
|
@ -150,60 +150,57 @@ void DiscreteConditional::print(const string& s,
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||||
double tol) const {
|
double tol) const {
|
||||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
|
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||||
return false;
|
return false;
|
||||||
else {
|
} else {
|
||||||
const DecisionTreeFactor& f(
|
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
static_cast<const DecisionTreeFactor&>(other));
|
|
||||||
return DecisionTreeFactor::equals(f, tol);
|
return DecisionTreeFactor::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||||
const DiscreteValues& parentsValues) {
|
const DiscreteValues& given,
|
||||||
|
bool forceComplete = true) {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
// branches based on the value of the parent variables.
|
// branches based on the value of the parent variables.
|
||||||
DiscreteConditional::ADT adt(conditional);
|
DiscreteConditional::ADT adt(conditional);
|
||||||
size_t value;
|
size_t value;
|
||||||
for (Key j : conditional.parents()) {
|
for (Key j : conditional.parents()) {
|
||||||
try {
|
try {
|
||||||
value = parentsValues.at(j);
|
value = given.at(j);
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||||
} catch (std::out_of_range&) {
|
} catch (std::out_of_range&) {
|
||||||
parentsValues.print("parentsValues: ");
|
if (forceComplete) {
|
||||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
given.print("parentsValues: ");
|
||||||
};
|
throw runtime_error(
|
||||||
|
"DiscreteConditional::Choose: parent value missing");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return adt;
|
return adt;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
|
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||||
const DiscreteValues& parentsValues) const {
|
const DiscreteValues& given) const {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
ADT adt = Choose(*this, given, false); // P(F|S=given)
|
||||||
// branches based on the value of the parent variables.
|
|
||||||
ADT adt(*this);
|
|
||||||
size_t value;
|
|
||||||
for (Key j : parents()) {
|
|
||||||
try {
|
|
||||||
value = parentsValues.at(j);
|
|
||||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
|
||||||
} catch (exception&) {
|
|
||||||
parentsValues.print("parentsValues: ");
|
|
||||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert ADT to factor.
|
// Collect all keys not in given.
|
||||||
DiscreteKeys discreteKeys;
|
DiscreteKeys dKeys;
|
||||||
for (Key j : frontals()) {
|
for (Key j : frontals()) {
|
||||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
dKeys.emplace_back(j, this->cardinality(j));
|
||||||
}
|
}
|
||||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
for (size_t i = nrFrontals(); i < size(); i++) {
|
||||||
|
Key j = keys_[i];
|
||||||
|
if (given.count(j) == 0) {
|
||||||
|
dKeys.emplace_back(j, this->cardinality(j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
const DiscreteValues& frontalValues) const {
|
const DiscreteValues& frontalValues) const {
|
||||||
// Get the big decision tree with all the levels, and then go down the
|
// Get the big decision tree with all the levels, and then go down the
|
||||||
|
@ -217,7 +214,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
} catch (exception&) {
|
} catch (exception&) {
|
||||||
frontalValues.print("frontalValues: ");
|
frontalValues.print("frontalValues: ");
|
||||||
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||||
};
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert ADT to factor.
|
// Convert ADT to factor.
|
||||||
|
@ -242,7 +239,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||||
// TODO(Abhijit): is this really the fastest way? He thinks it is.
|
|
||||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Initialize
|
// Initialize
|
||||||
|
@ -276,10 +272,8 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
(*values)[j] = sampled; // store result in partial solution
|
(*values)[j] = sampled; // store result in partial solution
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||||
|
|
||||||
// TODO: is this really the fastest way? I think it is.
|
|
||||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||||
|
|
||||||
// Then, find the max over all remaining
|
// Then, find the max over all remaining
|
||||||
|
|
|
@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
/**
|
||||||
DecisionTreeFactor::shared_ptr choose(
|
* @brief restrict to given *parent* values.
|
||||||
const DiscreteValues& parentsValues) const;
|
*
|
||||||
|
* Note: does not need be complete set. Examples:
|
||||||
|
*
|
||||||
|
* P(C|D,E) + . -> P(C|D,E)
|
||||||
|
* P(C|D,E) + E -> P(C|D)
|
||||||
|
* P(C|D,E) + D -> P(C|E)
|
||||||
|
* P(C|D,E) + D,E -> P(C)
|
||||||
|
* P(C|D,E) + C -> error!
|
||||||
|
*
|
||||||
|
* @return a shared_ptr to a new DiscreteConditional
|
||||||
|
*/
|
||||||
|
shared_ptr choose(const DiscreteValues& given) const;
|
||||||
|
|
||||||
/** Convert to a likelihood factor by providing value before bar. */
|
/** Convert to a likelihood factor by providing value before bar. */
|
||||||
DecisionTreeFactor::shared_ptr likelihood(
|
DecisionTreeFactor::shared_ptr likelihood(
|
||||||
|
|
|
@ -107,8 +107,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
void printSignature(
|
void printSignature(
|
||||||
string s = "Discrete Conditional: ",
|
string s = "Discrete Conditional: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
gtsam::DecisionTreeFactor* choose(
|
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
|
||||||
const gtsam::DiscreteValues& parentsValues) const;
|
|
||||||
gtsam::DecisionTreeFactor* likelihood(
|
gtsam::DecisionTreeFactor* likelihood(
|
||||||
const gtsam::DiscreteValues& frontalValues) const;
|
const gtsam::DiscreteValues& frontalValues) const;
|
||||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||||
|
|
|
@ -221,6 +221,34 @@ TEST(DiscreteConditional, likelihood) {
|
||||||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check choose on P(C|D,E)
|
||||||
|
TEST(DiscreteConditional, choose) {
|
||||||
|
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||||
|
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||||
|
|
||||||
|
// Case 1: no given values: no-op
|
||||||
|
DiscreteValues given;
|
||||||
|
auto actual1 = C_given_DE.choose(given);
|
||||||
|
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
|
||||||
|
|
||||||
|
// Case 2: 1 given value
|
||||||
|
given[D.first] = 1;
|
||||||
|
auto actual2 = C_given_DE.choose(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
|
||||||
|
DiscreteConditional expected2(C | E = "1/1 1/4");
|
||||||
|
EXPECT(assert_equal(expected2, *actual2, 1e-9));
|
||||||
|
|
||||||
|
// Case 2: 2 given values
|
||||||
|
given[E.first] = 0;
|
||||||
|
auto actual3 = C_given_DE.choose(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
|
||||||
|
DiscreteConditional expected3(C % "1/1");
|
||||||
|
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check markdown representation looks as expected, no parents.
|
// Check markdown representation looks as expected, no parents.
|
||||||
TEST(DiscreteConditional, markdown_prior) {
|
TEST(DiscreteConditional, markdown_prior) {
|
||||||
|
|
Loading…
Reference in New Issue