Added code to re-jigger Signature cpt so that frontal keys are always first, consistent with how the DiscreteElimination function works.
parent
9c5bba753c
commit
1ffddf72e1
|
@ -61,10 +61,9 @@ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
|
|||
}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const Signature& signature) :
|
||||
BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional(
|
||||
1) {
|
||||
}
|
||||
DiscreteConditional::DiscreteConditional(const Signature& signature)
|
||||
: BaseFactor(signature.discreteKeys(), signature.cpt()),
|
||||
BaseConditional(1) {}
|
||||
|
||||
/* ******************************************************************************** */
|
||||
void DiscreteConditional::print(const std::string& s,
|
||||
|
|
|
@ -90,22 +90,6 @@ public:
|
|||
/// GTSAM-style equals
|
||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
/// @name Parent keys are stored *first* in a DiscreteConditional, so re-jigger:
|
||||
/// @{
|
||||
|
||||
/** Iterator pointing to first frontal key. */
|
||||
typename DecisionTreeFactor::const_iterator beginFrontals() const { return endParents(); }
|
||||
|
||||
/** Iterator pointing past the last frontal key. */
|
||||
typename DecisionTreeFactor::const_iterator endFrontals() const { return end(); }
|
||||
|
||||
/** Iterator pointing to the first parent key. */
|
||||
typename DecisionTreeFactor::const_iterator beginParents() const { return begin(); }
|
||||
|
||||
/** Iterator pointing past the last parent key. */
|
||||
typename DecisionTreeFactor::const_iterator endParents() const { return end() - nrFrontals_; }
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
|
|
@ -122,28 +122,30 @@ namespace gtsam {
|
|||
key_(key) {
|
||||
}
|
||||
|
||||
DiscreteKeys Signature::discreteKeysParentsFirst() const {
|
||||
DiscreteKeys Signature::discreteKeys() const {
|
||||
DiscreteKeys keys;
|
||||
for(const DiscreteKey& key: parents_)
|
||||
keys.push_back(key);
|
||||
keys.push_back(key_);
|
||||
for (const DiscreteKey& key : parents_) keys.push_back(key);
|
||||
return keys;
|
||||
}
|
||||
|
||||
KeyVector Signature::indices() const {
|
||||
KeyVector js;
|
||||
js.push_back(key_.first);
|
||||
for(const DiscreteKey& key: parents_)
|
||||
js.push_back(key.first);
|
||||
for (const DiscreteKey& key : parents_) js.push_back(key.first);
|
||||
return js;
|
||||
}
|
||||
|
||||
vector<double> Signature::cpt() const {
|
||||
vector<double> cpt;
|
||||
if (table_) {
|
||||
for(const Row& row: *table_)
|
||||
for(const double& x: row)
|
||||
cpt.push_back(x);
|
||||
const size_t nrStates = table_->at(0).size();
|
||||
for (size_t j = 0; j < nrStates; j++) {
|
||||
for (const Row& row : *table_) {
|
||||
assert(row.size() == nrStates);
|
||||
cpt.push_back(row[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return cpt;
|
||||
}
|
||||
|
|
|
@ -86,8 +86,8 @@ namespace gtsam {
|
|||
return parents_;
|
||||
}
|
||||
|
||||
/** All keys, with variable key last */
|
||||
DiscreteKeys discreteKeysParentsFirst() const;
|
||||
/** All keys, with variable key first */
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
/** All key indices, with variable key first */
|
||||
KeyVector indices() const;
|
||||
|
|
|
@ -132,7 +132,7 @@ TEST(ADT, example3)
|
|||
|
||||
/** Convert Signature into CPT */
|
||||
ADT create(const Signature& signature) {
|
||||
ADT p(signature.discreteKeysParentsFirst(), signature.cpt());
|
||||
ADT p(signature.discreteKeys(), signature.cpt());
|
||||
static size_t count = 0;
|
||||
const DiscreteKey& key = signature.key();
|
||||
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
|
||||
|
@ -181,19 +181,20 @@ TEST(ADT, joint)
|
|||
dot(joint, "Asia-ASTLBEX");
|
||||
joint = apply(joint, pD, &mul);
|
||||
dot(joint, "Asia-ASTLBEXD");
|
||||
EXPECT_LONGS_EQUAL(346, (long)muls);
|
||||
EXPECT_LONGS_EQUAL(346, muls);
|
||||
gttoc_(asiaJoint);
|
||||
tictoc_getNode(asiaJointNode, asiaJoint);
|
||||
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
|
||||
tictoc_reset_();
|
||||
printCounts("Asia joint");
|
||||
|
||||
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
|
||||
ADT pASTL = pA;
|
||||
pASTL = apply(pASTL, pS, &mul);
|
||||
pASTL = apply(pASTL, pT, &mul);
|
||||
pASTL = apply(pASTL, pL, &mul);
|
||||
|
||||
// test combine
|
||||
// test combine to check that P(A) = \sum_{S,T,L} P(A,S,T,L)
|
||||
ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_);
|
||||
EXPECT(assert_equal(pA, fAa));
|
||||
ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_);
|
||||
|
|
|
@ -38,8 +38,8 @@ TEST( DiscreteConditional, constructors)
|
|||
EXPECT(expected1);
|
||||
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
|
||||
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
|
||||
EXPECT(expected1->endParents() == expected1->beginFrontals());
|
||||
EXPECT(expected1->endFrontals() == expected1->end());
|
||||
EXPECT(expected1->endParents() == expected1->end());
|
||||
EXPECT(expected1->endFrontals() == expected1->beginParents());
|
||||
|
||||
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||
DiscreteConditional actual1(1, f1);
|
||||
|
|
|
@ -11,36 +11,37 @@
|
|||
|
||||
/**
|
||||
* @file testSignature
|
||||
* @brief Tests focusing on the details of Signatures to evaluate boost compliance
|
||||
* @brief Tests focusing on the details of Signatures to evaluate boost
|
||||
* compliance
|
||||
* @author Alex Cunningham
|
||||
* @date Sept 19th 2011
|
||||
*/
|
||||
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
#include <boost/assign/std/vector.hpp>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace boost::assign;
|
||||
|
||||
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
||||
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(testSignature, simple_conditional) {
|
||||
Signature sig(X | Y = "1/1 2/3 1/4");
|
||||
DiscreteKey actKey = sig.key();
|
||||
LONGS_EQUAL((long)X.first, (long)actKey.first);
|
||||
LONGS_EQUAL(X.first, actKey.first);
|
||||
|
||||
DiscreteKeys actKeys = sig.discreteKeysParentsFirst();
|
||||
LONGS_EQUAL(2, (long)actKeys.size());
|
||||
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first);
|
||||
LONGS_EQUAL((long)X.first, (long)actKeys.back().first);
|
||||
DiscreteKeys actKeys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, actKeys.size());
|
||||
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||
|
||||
vector<double> actCpt = sig.cpt();
|
||||
EXPECT_LONGS_EQUAL(6, (long)actCpt.size());
|
||||
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -54,17 +55,20 @@ TEST(testSignature, simple_conditional_nonparser) {
|
|||
|
||||
Signature sig(X | Y = table);
|
||||
DiscreteKey actKey = sig.key();
|
||||
EXPECT_LONGS_EQUAL((long)X.first, (long)actKey.first);
|
||||
EXPECT_LONGS_EQUAL(X.first, actKey.first);
|
||||
|
||||
DiscreteKeys actKeys = sig.discreteKeysParentsFirst();
|
||||
LONGS_EQUAL(2, (long)actKeys.size());
|
||||
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first);
|
||||
LONGS_EQUAL((long)X.first, (long)actKeys.back().first);
|
||||
DiscreteKeys actKeys = sig.discreteKeys();
|
||||
LONGS_EQUAL(2, actKeys.size());
|
||||
LONGS_EQUAL(X.first, actKeys.front().first);
|
||||
LONGS_EQUAL(Y.first, actKeys.back().first);
|
||||
|
||||
vector<double> actCpt = sig.cpt();
|
||||
EXPECT_LONGS_EQUAL(6, (long)actCpt.size());
|
||||
EXPECT_LONGS_EQUAL(6, actCpt.size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue