Added code to re-jigger Signature cpt so that frontal keys are always first, consistent with how the DiscreteElimination function works.
parent
9c5bba753c
commit
1ffddf72e1
|
@ -61,10 +61,9 @@ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const Signature& signature) :
|
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
||||||
BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional(
|
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
||||||
1) {
|
BaseConditional(1) {}
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ******************************************************************************** */
|
||||||
void DiscreteConditional::print(const std::string& s,
|
void DiscreteConditional::print(const std::string& s,
|
||||||
|
|
|
@ -90,22 +90,6 @@ public:
|
||||||
/// GTSAM-style equals
|
/// GTSAM-style equals
|
||||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
|
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
|
||||||
|
|
||||||
/// @}
|
|
||||||
/// @name Parent keys are stored *first* in a DiscreteConditional, so re-jigger:
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/** Iterator pointing to first frontal key. */
|
|
||||||
typename DecisionTreeFactor::const_iterator beginFrontals() const { return endParents(); }
|
|
||||||
|
|
||||||
/** Iterator pointing past the last frontal key. */
|
|
||||||
typename DecisionTreeFactor::const_iterator endFrontals() const { return end(); }
|
|
||||||
|
|
||||||
/** Iterator pointing to the first parent key. */
|
|
||||||
typename DecisionTreeFactor::const_iterator beginParents() const { return begin(); }
|
|
||||||
|
|
||||||
/** Iterator pointing past the last parent key. */
|
|
||||||
typename DecisionTreeFactor::const_iterator endParents() const { return end() - nrFrontals_; }
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -122,28 +122,30 @@ namespace gtsam {
|
||||||
key_(key) {
|
key_(key) {
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteKeys Signature::discreteKeysParentsFirst() const {
|
DiscreteKeys Signature::discreteKeys() const {
|
||||||
DiscreteKeys keys;
|
DiscreteKeys keys;
|
||||||
for(const DiscreteKey& key: parents_)
|
|
||||||
keys.push_back(key);
|
|
||||||
keys.push_back(key_);
|
keys.push_back(key_);
|
||||||
|
for (const DiscreteKey& key : parents_) keys.push_back(key);
|
||||||
return keys;
|
return keys;
|
||||||
}
|
}
|
||||||
|
|
||||||
KeyVector Signature::indices() const {
|
KeyVector Signature::indices() const {
|
||||||
KeyVector js;
|
KeyVector js;
|
||||||
js.push_back(key_.first);
|
js.push_back(key_.first);
|
||||||
for(const DiscreteKey& key: parents_)
|
for (const DiscreteKey& key : parents_) js.push_back(key.first);
|
||||||
js.push_back(key.first);
|
|
||||||
return js;
|
return js;
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<double> Signature::cpt() const {
|
vector<double> Signature::cpt() const {
|
||||||
vector<double> cpt;
|
vector<double> cpt;
|
||||||
if (table_) {
|
if (table_) {
|
||||||
for(const Row& row: *table_)
|
const size_t nrStates = table_->at(0).size();
|
||||||
for(const double& x: row)
|
for (size_t j = 0; j < nrStates; j++) {
|
||||||
cpt.push_back(x);
|
for (const Row& row : *table_) {
|
||||||
|
assert(row.size() == nrStates);
|
||||||
|
cpt.push_back(row[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return cpt;
|
return cpt;
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,8 +86,8 @@ namespace gtsam {
|
||||||
return parents_;
|
return parents_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** All keys, with variable key last */
|
/** All keys, with variable key first */
|
||||||
DiscreteKeys discreteKeysParentsFirst() const;
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
/** All key indices, with variable key first */
|
/** All key indices, with variable key first */
|
||||||
KeyVector indices() const;
|
KeyVector indices() const;
|
||||||
|
|
|
@ -132,7 +132,7 @@ TEST(ADT, example3)
|
||||||
|
|
||||||
/** Convert Signature into CPT */
|
/** Convert Signature into CPT */
|
||||||
ADT create(const Signature& signature) {
|
ADT create(const Signature& signature) {
|
||||||
ADT p(signature.discreteKeysParentsFirst(), signature.cpt());
|
ADT p(signature.discreteKeys(), signature.cpt());
|
||||||
static size_t count = 0;
|
static size_t count = 0;
|
||||||
const DiscreteKey& key = signature.key();
|
const DiscreteKey& key = signature.key();
|
||||||
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
||||||
|
@ -181,19 +181,20 @@ TEST(ADT, joint)
|
||||||
dot(joint, "Asia-ASTLBEX");
|
dot(joint, "Asia-ASTLBEX");
|
||||||
joint = apply(joint, pD, &mul);
|
joint = apply(joint, pD, &mul);
|
||||||
dot(joint, "Asia-ASTLBEXD");
|
dot(joint, "Asia-ASTLBEXD");
|
||||||
EXPECT_LONGS_EQUAL(346, (long)muls);
|
EXPECT_LONGS_EQUAL(346, muls);
|
||||||
gttoc_(asiaJoint);
|
gttoc_(asiaJoint);
|
||||||
tictoc_getNode(asiaJointNode, asiaJoint);
|
tictoc_getNode(asiaJointNode, asiaJoint);
|
||||||
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
|
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
|
||||||
tictoc_reset_();
|
tictoc_reset_();
|
||||||
printCounts("Asia joint");
|
printCounts("Asia joint");
|
||||||
|
|
||||||
|
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
|
||||||
ADT pASTL = pA;
|
ADT pASTL = pA;
|
||||||
pASTL = apply(pASTL, pS, &mul);
|
pASTL = apply(pASTL, pS, &mul);
|
||||||
pASTL = apply(pASTL, pT, &mul);
|
pASTL = apply(pASTL, pT, &mul);
|
||||||
pASTL = apply(pASTL, pL, &mul);
|
pASTL = apply(pASTL, pL, &mul);
|
||||||
|
|
||||||
// test combine
|
// test combine to check that P(A) = \sum_{S,T,L} P(A,S,T,L)
|
||||||
ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_);
|
ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_);
|
||||||
EXPECT(assert_equal(pA, fAa));
|
EXPECT(assert_equal(pA, fAa));
|
||||||
ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_);
|
ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_);
|
||||||
|
|
|
@ -38,8 +38,8 @@ TEST( DiscreteConditional, constructors)
|
||||||
EXPECT(expected1);
|
EXPECT(expected1);
|
||||||
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
||||||
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
||||||
EXPECT(expected1->endParents() == expected1->beginFrontals());
|
EXPECT(expected1->endParents() == expected1->end());
|
||||||
EXPECT(expected1->endFrontals() == expected1->end());
|
EXPECT(expected1->endFrontals() == expected1->beginParents());
|
||||||
|
|
||||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
DiscreteConditional actual1(1, f1);
|
DiscreteConditional actual1(1, f1);
|
||||||
|
|
|
@ -11,17 +11,18 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @file testSignature
|
* @file testSignature
|
||||||
* @brief Tests focusing on the details of Signatures to evaluate boost compliance
|
* @brief Tests focusing on the details of Signatures to evaluate boost
|
||||||
|
* compliance
|
||||||
* @author Alex Cunningham
|
* @author Alex Cunningham
|
||||||
* @date Sept 19th 2011
|
* @date Sept 19th 2011
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <boost/assign/std/vector.hpp>
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
@ -32,15 +33,15 @@ DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
||||||
TEST(testSignature, simple_conditional) {
|
TEST(testSignature, simple_conditional) {
|
||||||
Signature sig(X | Y = "1/1 2/3 1/4");
|
Signature sig(X | Y = "1/1 2/3 1/4");
|
||||||
DiscreteKey actKey = sig.key();
|
DiscreteKey actKey = sig.key();
|
||||||
LONGS_EQUAL((long)X.first, (long)actKey.first);
|
LONGS_EQUAL(X.first, actKey.first);
|
||||||
|
|
||||||
DiscreteKeys actKeys = sig.discreteKeysParentsFirst();
|
DiscreteKeys actKeys = sig.discreteKeys();
|
||||||
LONGS_EQUAL(2, (long)actKeys.size());
|
LONGS_EQUAL(2, actKeys.size());
|
||||||
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first);
|
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||||
LONGS_EQUAL((long)X.first, (long)actKeys.back().first);
|
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||||
|
|
||||||
vector<double> actCpt = sig.cpt();
|
vector<double> actCpt = sig.cpt();
|
||||||
EXPECT_LONGS_EQUAL(6, (long)actCpt.size());
|
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -54,17 +55,20 @@ TEST(testSignature, simple_conditional_nonparser) {
|
||||||
|
|
||||||
Signature sig(X | Y = table);
|
Signature sig(X | Y = table);
|
||||||
DiscreteKey actKey = sig.key();
|
DiscreteKey actKey = sig.key();
|
||||||
EXPECT_LONGS_EQUAL((long)X.first, (long)actKey.first);
|
EXPECT_LONGS_EQUAL(X.first, actKey.first);
|
||||||
|
|
||||||
DiscreteKeys actKeys = sig.discreteKeysParentsFirst();
|
DiscreteKeys actKeys = sig.discreteKeys();
|
||||||
LONGS_EQUAL(2, (long)actKeys.size());
|
LONGS_EQUAL(2, actKeys.size());
|
||||||
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first);
|
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||||
LONGS_EQUAL((long)X.first, (long)actKeys.back().first);
|
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||||
|
|
||||||
vector<double> actCpt = sig.cpt();
|
vector<double> actCpt = sig.cpt();
|
||||||
EXPECT_LONGS_EQUAL(6, (long)actCpt.size());
|
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue