Added code to re-jigger Signature cpt so that frontal keys are always first, consistent with how the DiscreteElimination function works.

release/4.3a0
Frank dellaert 2020-07-12 16:50:55 -04:00
parent 9c5bba753c
commit 1ffddf72e1
7 changed files with 42 additions and 52 deletions

View File

@ -61,10 +61,9 @@ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
} }
/* ******************************************************************************** */ /* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const Signature& signature) : DiscreteConditional::DiscreteConditional(const Signature& signature)
BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional( : BaseFactor(signature.discreteKeys(), signature.cpt()),
1) { BaseConditional(1) {}
}
/* ******************************************************************************** */ /* ******************************************************************************** */
void DiscreteConditional::print(const std::string& s, void DiscreteConditional::print(const std::string& s,

View File

@ -90,22 +90,6 @@ public:
/// GTSAM-style equals /// GTSAM-style equals
bool equals(const DiscreteFactor& other, double tol = 1e-9) const; bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
/// @}
/// @name Parent keys are stored *first* in a DiscreteConditional, so re-jigger:
/// @{
/** Iterator pointing to first frontal key. */
typename DecisionTreeFactor::const_iterator beginFrontals() const { return endParents(); }
/** Iterator pointing past the last frontal key. */
typename DecisionTreeFactor::const_iterator endFrontals() const { return end(); }
/** Iterator pointing to the first parent key. */
typename DecisionTreeFactor::const_iterator beginParents() const { return begin(); }
/** Iterator pointing past the last parent key. */
typename DecisionTreeFactor::const_iterator endParents() const { return end() - nrFrontals_; }
/// @} /// @}
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{

View File

@ -122,28 +122,30 @@ namespace gtsam {
key_(key) { key_(key) {
} }
DiscreteKeys Signature::discreteKeysParentsFirst() const { DiscreteKeys Signature::discreteKeys() const {
DiscreteKeys keys; DiscreteKeys keys;
for(const DiscreteKey& key: parents_)
keys.push_back(key);
keys.push_back(key_); keys.push_back(key_);
for (const DiscreteKey& key : parents_) keys.push_back(key);
return keys; return keys;
} }
KeyVector Signature::indices() const { KeyVector Signature::indices() const {
KeyVector js; KeyVector js;
js.push_back(key_.first); js.push_back(key_.first);
for(const DiscreteKey& key: parents_) for (const DiscreteKey& key : parents_) js.push_back(key.first);
js.push_back(key.first);
return js; return js;
} }
vector<double> Signature::cpt() const { vector<double> Signature::cpt() const {
vector<double> cpt; vector<double> cpt;
if (table_) { if (table_) {
for(const Row& row: *table_) const size_t nrStates = table_->at(0).size();
for(const double& x: row) for (size_t j = 0; j < nrStates; j++) {
cpt.push_back(x); for (const Row& row : *table_) {
assert(row.size() == nrStates);
cpt.push_back(row[j]);
}
}
} }
return cpt; return cpt;
} }

View File

@ -86,8 +86,8 @@ namespace gtsam {
return parents_; return parents_;
} }
/** All keys, with variable key last */ /** All keys, with variable key first */
DiscreteKeys discreteKeysParentsFirst() const; DiscreteKeys discreteKeys() const;
/** All key indices, with variable key first */ /** All key indices, with variable key first */
KeyVector indices() const; KeyVector indices() const;

View File

@ -132,7 +132,7 @@ TEST(ADT, example3)
/** Convert Signature into CPT */ /** Convert Signature into CPT */
ADT create(const Signature& signature) { ADT create(const Signature& signature) {
ADT p(signature.discreteKeysParentsFirst(), signature.cpt()); ADT p(signature.discreteKeys(), signature.cpt());
static size_t count = 0; static size_t count = 0;
const DiscreteKey& key = signature.key(); const DiscreteKey& key = signature.key();
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
@ -181,19 +181,20 @@ TEST(ADT, joint)
dot(joint, "Asia-ASTLBEX"); dot(joint, "Asia-ASTLBEX");
joint = apply(joint, pD, &mul); joint = apply(joint, pD, &mul);
dot(joint, "Asia-ASTLBEXD"); dot(joint, "Asia-ASTLBEXD");
EXPECT_LONGS_EQUAL(346, (long)muls); EXPECT_LONGS_EQUAL(346, muls);
gttoc_(asiaJoint); gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint); tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall(); elapsed = asiaJointNode->secs() + asiaJointNode->wall();
tictoc_reset_(); tictoc_reset_();
printCounts("Asia joint"); printCounts("Asia joint");
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
ADT pASTL = pA; ADT pASTL = pA;
pASTL = apply(pASTL, pS, &mul); pASTL = apply(pASTL, pS, &mul);
pASTL = apply(pASTL, pT, &mul); pASTL = apply(pASTL, pT, &mul);
pASTL = apply(pASTL, pL, &mul); pASTL = apply(pASTL, pL, &mul);
// test combine // test combine to check that P(A) = \sum_{S,T,L} P(A,S,T,L)
ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_); ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_);
EXPECT(assert_equal(pA, fAa)); EXPECT(assert_equal(pA, fAa));
ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_); ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_);

View File

@ -38,8 +38,8 @@ TEST( DiscreteConditional, constructors)
EXPECT(expected1); EXPECT(expected1);
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
EXPECT(expected1->endParents() == expected1->beginFrontals()); EXPECT(expected1->endParents() == expected1->end());
EXPECT(expected1->endFrontals() == expected1->end()); EXPECT(expected1->endFrontals() == expected1->beginParents());
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional actual1(1, f1); DiscreteConditional actual1(1, f1);

View File

@ -11,36 +11,37 @@
/** /**
* @file testSignature * @file testSignature
* @brief Tests focusing on the details of Signatures to evaluate boost compliance * @brief Tests focusing on the details of Signatures to evaluate boost
* compliance
* @author Alex Cunningham * @author Alex Cunningham
* @date Sept 19th 2011 * @date Sept 19th 2011
*/ */
#include <boost/assign/std/vector.hpp>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <boost/assign/std/vector.hpp>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace boost::assign; using namespace boost::assign;
DiscreteKey X(0,2), Y(1,3), Z(2,2); DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
/* ************************************************************************* */ /* ************************************************************************* */
TEST(testSignature, simple_conditional) { TEST(testSignature, simple_conditional) {
Signature sig(X | Y = "1/1 2/3 1/4"); Signature sig(X | Y = "1/1 2/3 1/4");
DiscreteKey actKey = sig.key(); DiscreteKey actKey = sig.key();
LONGS_EQUAL((long)X.first, (long)actKey.first); LONGS_EQUAL(X.first, actKey.first);
DiscreteKeys actKeys = sig.discreteKeysParentsFirst(); DiscreteKeys actKeys = sig.discreteKeys();
LONGS_EQUAL(2, (long)actKeys.size()); LONGS_EQUAL(2, actKeys.size());
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first); LONGS_EQUAL(X.first, actKeys.front().first);
LONGS_EQUAL((long)X.first, (long)actKeys.back().first); LONGS_EQUAL(Y.first, actKeys.back().first);
vector<double> actCpt = sig.cpt(); vector<double> actCpt = sig.cpt();
EXPECT_LONGS_EQUAL(6, (long)actCpt.size()); EXPECT_LONGS_EQUAL(6, actCpt.size());
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -54,17 +55,20 @@ TEST(testSignature, simple_conditional_nonparser) {
Signature sig(X | Y = table); Signature sig(X | Y = table);
DiscreteKey actKey = sig.key(); DiscreteKey actKey = sig.key();
EXPECT_LONGS_EQUAL((long)X.first, (long)actKey.first); EXPECT_LONGS_EQUAL(X.first, actKey.first);
DiscreteKeys actKeys = sig.discreteKeysParentsFirst(); DiscreteKeys actKeys = sig.discreteKeys();
LONGS_EQUAL(2, (long)actKeys.size()); LONGS_EQUAL(2, actKeys.size());
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first); LONGS_EQUAL(X.first, actKeys.front().first);
LONGS_EQUAL((long)X.first, (long)actKeys.back().first); LONGS_EQUAL(Y.first, actKeys.back().first);
vector<double> actCpt = sig.cpt(); vector<double> actCpt = sig.cpt();
EXPECT_LONGS_EQUAL(6, (long)actCpt.size()); EXPECT_LONGS_EQUAL(6, actCpt.size());
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); } int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */ /* ************************************************************************* */