New, more powerful choose, yields a Conditional now
parent
2413fcb91f
commit
4a10ea89a5
|
@ -149,61 +149,58 @@ void DiscreteConditional::print(const string& s,
|
|||
|
||||
/* ******************************************************************************** */
|
||||
bool DiscreteConditional::equals(const DiscreteFactor& other,
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other))
|
||||
double tol) const {
|
||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||
return false;
|
||||
else {
|
||||
const DecisionTreeFactor& f(
|
||||
static_cast<const DecisionTreeFactor&>(other));
|
||||
} else {
|
||||
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
|
||||
return DecisionTreeFactor::equals(f, tol);
|
||||
}
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
|
||||
const DiscreteValues& parentsValues) {
|
||||
const DiscreteValues& given,
|
||||
bool forceComplete = true) {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
DiscreteConditional::ADT adt(conditional);
|
||||
size_t value;
|
||||
for (Key j : conditional.parents()) {
|
||||
try {
|
||||
value = parentsValues.at(j);
|
||||
value = given.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (std::out_of_range&) {
|
||||
parentsValues.print("parentsValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
if (forceComplete) {
|
||||
given.print("parentsValues: ");
|
||||
throw runtime_error(
|
||||
"DiscreteConditional::Choose: parent value missing");
|
||||
}
|
||||
}
|
||||
}
|
||||
return adt;
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
|
||||
const DiscreteValues& parentsValues) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
// branches based on the value of the parent variables.
|
||||
ADT adt(*this);
|
||||
size_t value;
|
||||
for (Key j : parents()) {
|
||||
try {
|
||||
value = parentsValues.at(j);
|
||||
adt = adt.choose(j, value); // ADT keeps getting smaller.
|
||||
} catch (exception&) {
|
||||
parentsValues.print("parentsValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: parent value missing");
|
||||
};
|
||||
}
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::shared_ptr DiscreteConditional::choose(
|
||||
const DiscreteValues& given) const {
|
||||
ADT adt = Choose(*this, given, false); // P(F|S=given)
|
||||
|
||||
// Convert ADT to factor.
|
||||
DiscreteKeys discreteKeys;
|
||||
// Collect all keys not in given.
|
||||
DiscreteKeys dKeys;
|
||||
for (Key j : frontals()) {
|
||||
discreteKeys.emplace_back(j, this->cardinality(j));
|
||||
dKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
|
||||
for (size_t i = nrFrontals(); i < size(); i++) {
|
||||
Key j = keys_[i];
|
||||
if (given.count(j) == 0) {
|
||||
dKeys.emplace_back(j, this->cardinality(j));
|
||||
}
|
||||
}
|
||||
return boost::make_shared<DiscreteConditional>(nrFrontals(), dKeys, adt);
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
const DiscreteValues& frontalValues) const {
|
||||
// Get the big decision tree with all the levels, and then go down the
|
||||
|
@ -217,7 +214,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
} catch (exception&) {
|
||||
frontalValues.print("frontalValues: ");
|
||||
throw runtime_error("DiscreteConditional::choose: frontal value missing");
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Convert ADT to factor.
|
||||
|
@ -242,7 +239,6 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
|
||||
// TODO(Abhijit): is this really the fastest way? He thinks it is.
|
||||
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
|
||||
|
||||
// Initialize
|
||||
|
@ -276,11 +272,9 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
|||
(*values)[j] = sampled; // store result in partial solution
|
||||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
/* ************************************************************************** */
|
||||
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
|
||||
|
||||
// TODO: is this really the fastest way? I think it is.
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
|
||||
|
||||
// Then, find the max over all remaining
|
||||
// TODO, only works for one key now, seems horribly slow this way
|
||||
|
|
|
@ -157,9 +157,20 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
return ADT::operator()(values);
|
||||
}
|
||||
|
||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
||||
DecisionTreeFactor::shared_ptr choose(
|
||||
const DiscreteValues& parentsValues) const;
|
||||
/**
|
||||
* @brief restrict to given *parent* values.
|
||||
*
|
||||
* Note: does not need be complete set. Examples:
|
||||
*
|
||||
* P(C|D,E) + . -> P(C|D,E)
|
||||
* P(C|D,E) + E -> P(C|D)
|
||||
* P(C|D,E) + D -> P(C|E)
|
||||
* P(C|D,E) + D,E -> P(C)
|
||||
* P(C|D,E) + C -> error!
|
||||
*
|
||||
* @return a shared_ptr to a new DiscreteConditional
|
||||
*/
|
||||
shared_ptr choose(const DiscreteValues& given) const;
|
||||
|
||||
/** Convert to a likelihood factor by providing value before bar. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(
|
||||
|
|
|
@ -107,8 +107,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
|||
void printSignature(
|
||||
string s = "Discrete Conditional: ",
|
||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||
gtsam::DecisionTreeFactor* choose(
|
||||
const gtsam::DiscreteValues& parentsValues) const;
|
||||
gtsam::DecisionTreeFactor* choose(const gtsam::DiscreteValues& given) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(
|
||||
const gtsam::DiscreteValues& frontalValues) const;
|
||||
gtsam::DecisionTreeFactor* likelihood(size_t value) const;
|
||||
|
|
|
@ -221,6 +221,34 @@ TEST(DiscreteConditional, likelihood) {
|
|||
EXPECT(assert_equal(expected1, *actual1, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check choose on P(C|D,E)
|
||||
TEST(DiscreteConditional, choose) {
|
||||
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||
|
||||
// Case 1: no given values: no-op
|
||||
DiscreteValues given;
|
||||
auto actual1 = C_given_DE.choose(given);
|
||||
EXPECT(assert_equal(C_given_DE, *actual1, 1e-9));
|
||||
|
||||
// Case 2: 1 given value
|
||||
given[D.first] = 1;
|
||||
auto actual2 = C_given_DE.choose(given);
|
||||
EXPECT_LONGS_EQUAL(1, actual2->nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(1, actual2->nrParents());
|
||||
DiscreteConditional expected2(C | E = "1/1 1/4");
|
||||
EXPECT(assert_equal(expected2, *actual2, 1e-9));
|
||||
|
||||
// Case 2: 2 given values
|
||||
given[E.first] = 0;
|
||||
auto actual3 = C_given_DE.choose(given);
|
||||
EXPECT_LONGS_EQUAL(1, actual3->nrFrontals());
|
||||
EXPECT_LONGS_EQUAL(0, actual3->nrParents());
|
||||
DiscreteConditional expected3(C % "1/1");
|
||||
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, no parents.
|
||||
TEST(DiscreteConditional, markdown_prior) {
|
||||
|
|
Loading…
Reference in New Issue