Added code to re-jigger Signature cpt so that frontal keys are always first, consistent with how the DiscreteElimination function works.

release/4.3a0
Frank dellaert 2020-07-12 16:50:55 -04:00
parent 9c5bba753c
commit 1ffddf72e1
7 changed files with 42 additions and 52 deletions

View File

@ -61,10 +61,9 @@ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
}
/* ******************************************************************************** */
DiscreteConditional::DiscreteConditional(const Signature& signature) :
BaseFactor(signature.discreteKeysParentsFirst(), signature.cpt()), BaseConditional(
1) {
}
DiscreteConditional::DiscreteConditional(const Signature& signature)
: BaseFactor(signature.discreteKeys(), signature.cpt()),
BaseConditional(1) {}
/* ******************************************************************************** */
void DiscreteConditional::print(const std::string& s,

View File

@ -90,22 +90,6 @@ public:
/// GTSAM-style equals
bool equals(const DiscreteFactor& other, double tol = 1e-9) const;
/// @}
/// @name Parent keys are stored *first* in a DiscreteConditional, so re-jigger:
/// @{
/** Iterator pointing to first frontal key. */
typename DecisionTreeFactor::const_iterator beginFrontals() const { return endParents(); }
/** Iterator pointing past the last frontal key. */
typename DecisionTreeFactor::const_iterator endFrontals() const { return end(); }
/** Iterator pointing to the first parent key. */
typename DecisionTreeFactor::const_iterator beginParents() const { return begin(); }
/** Iterator pointing past the last parent key. */
typename DecisionTreeFactor::const_iterator endParents() const { return end() - nrFrontals_; }
/// @}
/// @name Standard Interface
/// @{

View File

@ -122,28 +122,30 @@ namespace gtsam {
key_(key) {
}
DiscreteKeys Signature::discreteKeysParentsFirst() const {
DiscreteKeys Signature::discreteKeys() const {
DiscreteKeys keys;
for(const DiscreteKey& key: parents_)
keys.push_back(key);
keys.push_back(key_);
for (const DiscreteKey& key : parents_) keys.push_back(key);
return keys;
}
KeyVector Signature::indices() const {
KeyVector js;
js.push_back(key_.first);
for(const DiscreteKey& key: parents_)
js.push_back(key.first);
for (const DiscreteKey& key : parents_) js.push_back(key.first);
return js;
}
vector<double> Signature::cpt() const {
vector<double> cpt;
if (table_) {
for(const Row& row: *table_)
for(const double& x: row)
cpt.push_back(x);
const size_t nrStates = table_->at(0).size();
for (size_t j = 0; j < nrStates; j++) {
for (const Row& row : *table_) {
assert(row.size() == nrStates);
cpt.push_back(row[j]);
}
}
}
return cpt;
}

View File

@ -86,8 +86,8 @@ namespace gtsam {
return parents_;
}
/** All keys, with variable key last */
DiscreteKeys discreteKeysParentsFirst() const;
/** All keys, with variable key first */
DiscreteKeys discreteKeys() const;
/** All key indices, with variable key first */
KeyVector indices() const;

View File

@ -132,7 +132,7 @@ TEST(ADT, example3)
/** Convert Signature into CPT */
ADT create(const Signature& signature) {
ADT p(signature.discreteKeysParentsFirst(), signature.cpt());
ADT p(signature.discreteKeys(), signature.cpt());
static size_t count = 0;
const DiscreteKey& key = signature.key();
string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str();
@ -181,19 +181,20 @@ TEST(ADT, joint)
dot(joint, "Asia-ASTLBEX");
joint = apply(joint, pD, &mul);
dot(joint, "Asia-ASTLBEXD");
EXPECT_LONGS_EQUAL(346, (long)muls);
EXPECT_LONGS_EQUAL(346, muls);
gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
tictoc_reset_();
printCounts("Asia joint");
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
ADT pASTL = pA;
pASTL = apply(pASTL, pS, &mul);
pASTL = apply(pASTL, pT, &mul);
pASTL = apply(pASTL, pL, &mul);
// test combine
// test combine to check that P(A) = \sum_{S,T,L} P(A,S,T,L)
ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_);
EXPECT(assert_equal(pA, fAa));
ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_);

View File

@ -38,8 +38,8 @@ TEST( DiscreteConditional, constructors)
EXPECT(expected1);
EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals()));
EXPECT_LONGS_EQUAL(2, *(expected1->beginParents()));
EXPECT(expected1->endParents() == expected1->beginFrontals());
EXPECT(expected1->endFrontals() == expected1->end());
EXPECT(expected1->endParents() == expected1->end());
EXPECT(expected1->endFrontals() == expected1->beginParents());
DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
DiscreteConditional actual1(1, f1);

View File

@ -11,36 +11,37 @@
/**
* @file testSignature
* @brief Tests focusing on the details of Signatures to evaluate boost compliance
* @brief Tests focusing on the details of Signatures to evaluate boost
* compliance
* @author Alex Cunningham
* @date Sept 19th 2011
*/
#include <boost/assign/std/vector.hpp>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h>
#include <boost/assign/std/vector.hpp>
using namespace std;
using namespace gtsam;
using namespace boost::assign;
DiscreteKey X(0,2), Y(1,3), Z(2,2);
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
/* ************************************************************************* */
TEST(testSignature, simple_conditional) {
Signature sig(X | Y = "1/1 2/3 1/4");
DiscreteKey actKey = sig.key();
LONGS_EQUAL((long)X.first, (long)actKey.first);
LONGS_EQUAL(X.first, actKey.first);
DiscreteKeys actKeys = sig.discreteKeysParentsFirst();
LONGS_EQUAL(2, (long)actKeys.size());
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first);
LONGS_EQUAL((long)X.first, (long)actKeys.back().first);
DiscreteKeys actKeys = sig.discreteKeys();
LONGS_EQUAL(2, actKeys.size());
LONGS_EQUAL(X.first, actKeys.front().first);
LONGS_EQUAL(Y.first, actKeys.back().first);
vector<double> actCpt = sig.cpt();
EXPECT_LONGS_EQUAL(6, (long)actCpt.size());
EXPECT_LONGS_EQUAL(6, actCpt.size());
}
/* ************************************************************************* */
@ -54,17 +55,20 @@ TEST(testSignature, simple_conditional_nonparser) {
Signature sig(X | Y = table);
DiscreteKey actKey = sig.key();
EXPECT_LONGS_EQUAL((long)X.first, (long)actKey.first);
EXPECT_LONGS_EQUAL(X.first, actKey.first);
DiscreteKeys actKeys = sig.discreteKeysParentsFirst();
LONGS_EQUAL(2, (long)actKeys.size());
LONGS_EQUAL((long)Y.first, (long)actKeys.front().first);
LONGS_EQUAL((long)X.first, (long)actKeys.back().first);
DiscreteKeys actKeys = sig.discreteKeys();
LONGS_EQUAL(2, actKeys.size());
LONGS_EQUAL(X.first, actKeys.front().first);
LONGS_EQUAL(Y.first, actKeys.back().first);
vector<double> actCpt = sig.cpt();
EXPECT_LONGS_EQUAL(6, (long)actCpt.size());
EXPECT_LONGS_EQUAL(6, actCpt.size());
}
/* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */