Replace buggy/awkward Combine with principled operator*, remove toFactor
parent
0909e98389
commit
f9dd225ca5
|
@ -30,6 +30,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using std::stringstream;
|
using std::stringstream;
|
||||||
|
@ -38,37 +39,77 @@ using std::pair;
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
template class GTSAM_EXPORT Conditional<DecisionTreeFactor, DiscreteConditional> ;
|
template class GTSAM_EXPORT
|
||||||
|
Conditional<DecisionTreeFactor, DiscreteConditional>;
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||||
const DecisionTreeFactor& f) :
|
const DecisionTreeFactor& f)
|
||||||
BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {
|
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||||
const DecisionTreeFactor& marginal) :
|
const DiscreteKeys& keys,
|
||||||
BaseFactor(
|
const ADT& potentials)
|
||||||
ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional(
|
: BaseFactor(keys, potentials), BaseConditional(nrFrontals) {}
|
||||||
joint.size()-marginal.size()) {
|
|
||||||
if (ISDEBUG("DiscreteConditional::DiscreteConditional"))
|
|
||||||
cout << (firstFrontalKey()) << endl; //TODO Print all keys
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal, const Ordering& orderedKeys) :
|
const DecisionTreeFactor& marginal)
|
||||||
DiscreteConditional(joint, marginal) {
|
: BaseFactor(joint / marginal),
|
||||||
|
BaseConditional(joint.size() - marginal.size()) {}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
|
const DecisionTreeFactor& marginal,
|
||||||
|
const Ordering& orderedKeys)
|
||||||
|
: DiscreteConditional(joint, marginal) {
|
||||||
keys_.clear();
|
keys_.clear();
|
||||||
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
||||||
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
||||||
BaseConditional(1) {}
|
BaseConditional(1) {}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional DiscreteConditional::operator*(
|
||||||
|
const DiscreteConditional& other) const {
|
||||||
|
// Take union of frontal keys
|
||||||
|
std::set<Key> newFrontals;
|
||||||
|
for (auto&& key : this->frontals()) newFrontals.insert(key);
|
||||||
|
for (auto&& key : other.frontals()) newFrontals.insert(key);
|
||||||
|
|
||||||
|
// Check if frontals overlapped
|
||||||
|
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::operator* called with overlapping frontal keys.");
|
||||||
|
|
||||||
|
// Now, add cardinalities.
|
||||||
|
DiscreteKeys discreteKeys;
|
||||||
|
for (auto&& key : frontals())
|
||||||
|
discreteKeys.emplace_back(key, cardinality(key));
|
||||||
|
for (auto&& key : other.frontals())
|
||||||
|
discreteKeys.emplace_back(key, other.cardinality(key));
|
||||||
|
|
||||||
|
// Sort
|
||||||
|
std::sort(discreteKeys.begin(), discreteKeys.end());
|
||||||
|
|
||||||
|
// Add parents to set, to make them unique
|
||||||
|
std::set<DiscreteKey> parents;
|
||||||
|
for (auto&& key : this->parents())
|
||||||
|
if (!newFrontals.count(key)) parents.emplace(key, cardinality(key));
|
||||||
|
for (auto&& key : other.parents())
|
||||||
|
if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key));
|
||||||
|
|
||||||
|
// Finally, add parents to keys, in order
|
||||||
|
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
||||||
|
|
||||||
|
ADT product = ADT::apply(other, ADT::Ring::mul);
|
||||||
|
return DiscreteConditional(newFrontals.size(), discreteKeys, product);
|
||||||
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
void DiscreteConditional::print(const string& s,
|
void DiscreteConditional::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
|
@ -82,7 +123,7 @@ void DiscreteConditional::print(const string& s,
|
||||||
cout << formatter(*it) << " ";
|
cout << formatter(*it) << " ";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << ")";
|
cout << "):\n";
|
||||||
ADT::print("");
|
ADT::print("");
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,14 +49,21 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** default constructor needed for serialization */
|
/// Default constructor needed for serialization.
|
||||||
DiscreteConditional() {}
|
DiscreteConditional() {}
|
||||||
|
|
||||||
/** constructor from factor */
|
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
||||||
|
* `nFrontals` keys as frontals, in the order given.
|
||||||
|
*/
|
||||||
|
DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const ADT& potentials);
|
||||||
|
|
||||||
/** Construct from signature */
|
/** Construct from signature */
|
||||||
DiscreteConditional(const Signature& signature);
|
explicit DiscreteConditional(const Signature& signature);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from key, parents, and a Signature::Table specifying the
|
* Construct from key, parents, and a Signature::Table specifying the
|
||||||
|
@ -86,27 +93,38 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
|
||||||
: DiscreteConditional(Signature(key, {}, spec)) {}
|
: DiscreteConditional(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
/**
|
||||||
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
|
*/
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal);
|
const DecisionTreeFactor& marginal);
|
||||||
|
|
||||||
/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
|
/**
|
||||||
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
|
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||||
|
*/
|
||||||
DiscreteConditional(const DecisionTreeFactor& joint,
|
DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
const DecisionTreeFactor& marginal,
|
const DecisionTreeFactor& marginal,
|
||||||
const Ordering& orderedKeys);
|
const Ordering& orderedKeys);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Combine several conditional into a single one.
|
* @brief Combine two conditionals, yielding a new conditional with the union
|
||||||
* The conditionals must be given in increasing order, meaning that the
|
* of the frontal keys, ordered by gtsam::Key.
|
||||||
* parents of any conditional may not include a conditional coming before it.
|
*
|
||||||
* @param firstConditional Iterator to the first conditional to combine, must
|
* The two conditionals must make a valid Bayes net fragment, i.e.,
|
||||||
* dereference to a shared_ptr<DiscreteConditional>.
|
* the frontal variables cannot overlap, and must be acyclic:
|
||||||
* @param lastConditional Iterator to after the last conditional to combine,
|
* Example of correct use:
|
||||||
* must dereference to a shared_ptr<DiscreteConditional>.
|
* P(A,B) = P(A|B) * P(B)
|
||||||
* */
|
* P(A,B|C) = P(A|B) * P(B|C)
|
||||||
template <typename ITERATOR>
|
* P(A,B,C) = P(A,B|C) * P(C)
|
||||||
static shared_ptr Combine(ITERATOR firstConditional,
|
* Example of incorrect use:
|
||||||
ITERATOR lastConditional);
|
* P(A|B) * P(A|C) = ?
|
||||||
|
* P(A|B) * P(B|A) = ?
|
||||||
|
* We check for overlapping frontals, but do *not* check for cyclic.
|
||||||
|
*/
|
||||||
|
DiscreteConditional operator*(const DiscreteConditional& other) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
@ -136,11 +154,6 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Convert to a factor */
|
|
||||||
DecisionTreeFactor::shared_ptr toFactor() const {
|
|
||||||
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Restrict to given parent values, returns DecisionTreeFactor */
|
/** Restrict to given parent values, returns DecisionTreeFactor */
|
||||||
DecisionTreeFactor::shared_ptr choose(
|
DecisionTreeFactor::shared_ptr choose(
|
||||||
const DiscreteValues& parentsValues) const;
|
const DiscreteValues& parentsValues) const;
|
||||||
|
@ -208,23 +221,4 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
template <>
|
template <>
|
||||||
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
struct traits<DiscreteConditional> : public Testable<DiscreteConditional> {};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
template <typename ITERATOR>
|
|
||||||
DiscreteConditional::shared_ptr DiscreteConditional::Combine(
|
|
||||||
ITERATOR firstConditional, ITERATOR lastConditional) {
|
|
||||||
// TODO: check for being a clique
|
|
||||||
|
|
||||||
// multiply all the potentials of the given conditionals
|
|
||||||
size_t nrFrontals = 0;
|
|
||||||
DecisionTreeFactor product;
|
|
||||||
for (ITERATOR it = firstConditional; it != lastConditional;
|
|
||||||
++it, ++nrFrontals) {
|
|
||||||
DiscreteConditional::shared_ptr c = *it;
|
|
||||||
DecisionTreeFactor::shared_ptr factor = c->toFactor();
|
|
||||||
product = (*factor) * product;
|
|
||||||
}
|
|
||||||
// and then create a new multi-frontal conditional
|
|
||||||
return boost::make_shared<DiscreteConditional>(nrFrontals, product);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -102,7 +102,6 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
void printSignature(
|
void printSignature(
|
||||||
string s = "Discrete Conditional: ",
|
string s = "Discrete Conditional: ",
|
||||||
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const;
|
||||||
gtsam::DecisionTreeFactor* toFactor() const;
|
|
||||||
gtsam::DecisionTreeFactor* choose(
|
gtsam::DecisionTreeFactor* choose(
|
||||||
const gtsam::DiscreteValues& parentsValues) const;
|
const gtsam::DiscreteValues& parentsValues) const;
|
||||||
gtsam::DecisionTreeFactor* likelihood(
|
gtsam::DecisionTreeFactor* likelihood(
|
||||||
|
|
|
@ -60,7 +60,7 @@ TEST(DecisionTreeFactor, multiplication) {
|
||||||
DiscretePrior prior(v1 % "1/3");
|
DiscretePrior prior(v1 % "1/3");
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
|
||||||
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
|
DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
|
||||||
CHECK(assert_equal(expected, prior * f1));
|
CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * f1));
|
||||||
CHECK(assert_equal(expected, f1 * prior));
|
CHECK(assert_equal(expected, f1 * prior));
|
||||||
|
|
||||||
// Multiply two factors
|
// Multiply two factors
|
||||||
|
|
|
@ -34,20 +34,21 @@ using namespace gtsam;
|
||||||
TEST(DiscreteConditional, constructors) {
|
TEST(DiscreteConditional, constructors) {
|
||||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||||
|
|
||||||
DiscreteConditional expected(X | Y = "1/1 2/3 1/4");
|
DiscreteConditional actual(X | Y = "1/1 2/3 1/4");
|
||||||
EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals()));
|
EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals()));
|
||||||
EXPECT_LONGS_EQUAL(2, *(expected.beginParents()));
|
EXPECT_LONGS_EQUAL(2, *(actual.beginParents()));
|
||||||
EXPECT(expected.endParents() == expected.end());
|
EXPECT(actual.endParents() == actual.end());
|
||||||
EXPECT(expected.endFrontals() == expected.beginParents());
|
EXPECT(actual.endFrontals() == actual.beginParents());
|
||||||
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional actual1(1, f1);
|
DiscreteConditional expected1(1, f1);
|
||||||
EXPECT(assert_equal(expected, actual1, 1e-9));
|
EXPECT(assert_equal(expected1, actual, 1e-9));
|
||||||
|
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||||
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -61,6 +62,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
||||||
r3 += 1.0, 4.0;
|
r3 += 1.0, 4.0;
|
||||||
table += r1, r2, r3;
|
table += r1, r2, r3;
|
||||||
DiscreteConditional actual1(X, {Y}, table);
|
DiscreteConditional actual1(X, {Y}, table);
|
||||||
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional expected1(1, f1);
|
DiscreteConditional expected1(1, f1);
|
||||||
EXPECT(assert_equal(expected1, actual1, 1e-9));
|
EXPECT(assert_equal(expected1, actual1, 1e-9));
|
||||||
|
@ -68,43 +70,109 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
||||||
DecisionTreeFactor f2(
|
DecisionTreeFactor f2(
|
||||||
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
|
||||||
DiscreteConditional actual2(1, f2);
|
DiscreteConditional actual2(1, f2);
|
||||||
EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9));
|
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
|
||||||
|
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, constructors2) {
|
TEST(DiscreteConditional, constructors2) {
|
||||||
// Declare keys and ordering
|
|
||||||
DiscreteKey C(0, 2), B(1, 2);
|
DiscreteKey C(0, 2), B(1, 2);
|
||||||
DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25");
|
|
||||||
Signature signature((C | B) = "4/1 3/1");
|
Signature signature((C | B) = "4/1 3/1");
|
||||||
DiscreteConditional expected(signature);
|
DiscreteConditional actual(signature);
|
||||||
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
|
||||||
EXPECT(assert_equal(*expectedFactor, actual));
|
DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25");
|
||||||
|
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, constructors3) {
|
TEST(DiscreteConditional, constructors3) {
|
||||||
// Declare keys and ordering
|
|
||||||
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||||
DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
|
||||||
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
||||||
DiscreteConditional expected(signature);
|
DiscreteConditional actual(signature);
|
||||||
DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor();
|
|
||||||
EXPECT(assert_equal(*expectedFactor, actual));
|
DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8");
|
||||||
|
EXPECT(assert_equal(expected, static_cast<DecisionTreeFactor>(actual)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, Combine) {
|
// Check calculation of joint P(A,B)
|
||||||
|
TEST(DiscreteConditional, Multiply) {
|
||||||
DiscreteKey A(0, 2), B(1, 2);
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
vector<DiscreteConditional::shared_ptr> c;
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
|
DiscreteConditional prior(B % "1/2");
|
||||||
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
|
|
||||||
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
|
|
||||||
DiscreteConditional expected(2, factor);
|
|
||||||
auto actual = DiscreteConditional::Combine(c.begin(), c.end());
|
|
||||||
EXPECT(assert_equal(expected, *actual, 1e-5));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// P(A,B) = P(A|B) * P(B) = P(B) * P(A|B)
|
||||||
|
for (auto&& actual : {prior * conditional, conditional * prior}) {
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{0, 1}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of conditional joint P(A,B|C)
|
||||||
|
TEST(DiscreteConditional, Multiply2) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
|
||||||
|
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||||
|
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||||
|
|
||||||
|
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||||
|
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual.nrParents());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{0, 1}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of conditional joint P(A,B|C), double check keys
|
||||||
|
TEST(DiscreteConditional, Multiply3) {
|
||||||
|
DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!!
|
||||||
|
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||||
|
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||||
|
|
||||||
|
// P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B)
|
||||||
|
for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) {
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual.nrParents());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{1, 2}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
|
||||||
|
TEST(DiscreteConditional, Multiply4) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2);
|
||||||
|
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||||
|
DiscreteConditional B_given_D(B | D = "1/3 3/1");
|
||||||
|
DiscreteConditional AB_given_D = A_given_B * B_given_D;
|
||||||
|
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||||
|
|
||||||
|
// P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D)
|
||||||
|
for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) {
|
||||||
|
EXPECT_LONGS_EQUAL(3, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrParents());
|
||||||
|
KeyVector frontals(actual.beginFrontals(), actual.endFrontals());
|
||||||
|
EXPECT((frontals == KeyVector{0, 1, 2}));
|
||||||
|
KeyVector parents(actual.beginParents(), actual.endParents());
|
||||||
|
EXPECT((parents == KeyVector{3, 4}));
|
||||||
|
for (auto&& it : actual.enumerate()) {
|
||||||
|
const DiscreteValues& v = it.first;
|
||||||
|
EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscreteConditional, likelihood) {
|
TEST(DiscreteConditional, likelihood) {
|
||||||
DiscreteKey X(0, 2), Y(1, 3);
|
DiscreteKey X(0, 2), Y(1, 3);
|
||||||
|
|
|
@ -42,6 +42,19 @@ TEST(DiscretePrior, constructors) {
|
||||||
EXPECT(assert_equal(expected, actual2, 1e-9));
|
EXPECT(assert_equal(expected, actual2, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscretePrior, Multiply) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||||
|
DiscretePrior prior(B, "1/2");
|
||||||
|
DiscreteConditional actual = prior * conditional; // P(A|B) * P(B)
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B)
|
||||||
|
DecisionTreeFactor factor(A & B, "1 4 2 2");
|
||||||
|
DiscreteConditional expected(2, factor);
|
||||||
|
EXPECT(assert_equal(expected, actual, 1e-5));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, operator) {
|
TEST(DiscretePrior, operator) {
|
||||||
DiscretePrior prior(X % "2/3");
|
DiscretePrior prior(X % "2/3");
|
||||||
|
|
Loading…
Reference in New Issue