Get rid of unreliable timing

release/4.3a0
Frank Dellaert 2024-09-23 17:46:11 -07:00
parent 5aa5222edb
commit 4c7d3b5a50
1 changed files with 7 additions and 80 deletions

View File

@ -20,12 +20,9 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
using namespace std; using namespace std;
@ -71,16 +68,14 @@ void dot(const T& f, const string& filename) {
// instrumented operators // instrumented operators
/* ************************************************************************** */ /* ************************************************************************** */
size_t muls = 0, adds = 0; size_t muls = 0, adds = 0;
double elapsed;
void resetCounts() { void resetCounts() {
muls = 0; muls = 0;
adds = 0; adds = 0;
} }
void printCounts(const string& s) { void printCounts(const string& s) {
#ifndef DISABLE_TIMING #ifndef DISABLE_TIMING
cout << s << ": " << std::setw(3) << muls << " muls, " << cout << s << ": " << std::setw(3) << muls << " muls, " << std::setw(3) << adds
std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms." << " adds" << endl;
<< endl;
#endif #endif
resetCounts(); resetCounts();
} }
@ -131,7 +126,8 @@ ADT create(const Signature& signature) {
static size_t count = 0; static size_t count = 0;
const DiscreteKey& key = signature.key(); const DiscreteKey& key = signature.key();
std::stringstream ss; std::stringstream ss;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first; ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-"
<< key.first;
string DOTfile = ss.str(); string DOTfile = ss.str();
dot(p, DOTfile); dot(p, DOTfile);
return p; return p;
@ -143,8 +139,6 @@ TEST(ADT, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2); D(7, 2);
resetCounts();
gttic_(asiaCPTs);
ADT pA = create(A % "99/1"); ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50"); ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5"); ADT pT = create(T | A = "99/1 95/5");
@ -153,15 +147,9 @@ TEST(ADT, joint) {
ADT pE = create((E | T, L) = "F T T T"); ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98"); ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(asiaCPTs);
tictoc_getNode(asiaCPTsNode, asiaCPTs);
elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall();
tictoc_reset_();
printCounts("Asia CPTs");
// Create joint // Create joint
resetCounts(); resetCounts();
gttic_(asiaJoint);
ADT joint = pA; ADT joint = pA;
dot(joint, "Asia-A"); dot(joint, "Asia-A");
joint = apply(joint, pS, &mul); joint = apply(joint, pS, &mul);
@ -183,10 +171,6 @@ TEST(ADT, joint) {
#else #else
EXPECT_LONGS_EQUAL(508, muls); EXPECT_LONGS_EQUAL(508, muls);
#endif #endif
gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
tictoc_reset_();
printCounts("Asia joint"); printCounts("Asia joint");
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S) // Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
@ -208,8 +192,6 @@ TEST(ADT, inference) {
DiscreteKey A(0, 2), D(1, 2), // DiscreteKey A(0, 2), D(1, 2), //
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2); B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
resetCounts();
gttic_(infCPTs);
ADT pA = create(A % "99/1"); ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50"); ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5"); ADT pT = create(T | A = "99/1 95/5");
@ -218,15 +200,9 @@ TEST(ADT, inference) {
ADT pE = create((E | T, L) = "F T T T"); ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98"); ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(infCPTs);
tictoc_getNode(infCPTsNode, infCPTs);
elapsed = infCPTsNode->secs() + infCPTsNode->wall();
tictoc_reset_();
// printCounts("Inference CPTs");
// Create joint // Create joint, note different ordering than above: different tree!
resetCounts(); resetCounts();
gttic_(asiaProd);
ADT joint = pA; ADT joint = pA;
dot(joint, "Joint-Product-A"); dot(joint, "Joint-Product-A");
joint = apply(joint, pS, &mul); joint = apply(joint, pS, &mul);
@ -248,14 +224,9 @@ TEST(ADT, inference) {
#else #else
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
#endif #endif
gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
tictoc_reset_();
printCounts("Asia product"); printCounts("Asia product");
resetCounts(); resetCounts();
gttic_(asiaSum);
ADT marginal = joint; ADT marginal = joint;
marginal = marginal.combine(X, &add_); marginal = marginal.combine(X, &add_);
dot(marginal, "Joint-Sum-ADBLEST"); dot(marginal, "Joint-Sum-ADBLEST");
@ -270,10 +241,6 @@ TEST(ADT, inference) {
#else #else
EXPECT_LONGS_EQUAL(240, (long)adds); EXPECT_LONGS_EQUAL(240, (long)adds);
#endif #endif
gttoc_(asiaSum);
tictoc_getNode(asiaSumNode, asiaSum);
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
tictoc_reset_();
printCounts("Asia sum"); printCounts("Asia sum");
} }
@ -282,7 +249,6 @@ TEST(ADT, factor_graph) {
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2); DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
resetCounts(); resetCounts();
gttic_(createCPTs);
ADT pS = create(S % "50/50"); ADT pS = create(S % "50/50");
ADT pT = create(T % "95/5"); ADT pT = create(T % "95/5");
ADT pL = create(L | S = "99/1 90/10"); ADT pL = create(L | S = "99/1 90/10");
@ -290,15 +256,10 @@ TEST(ADT, factor_graph) {
ADT pX = create(X | E = "95/5 2/98"); ADT pX = create(X | E = "95/5 2/98");
ADT pD = create(B | E = "1/8 7/9"); ADT pD = create(B | E = "1/8 7/9");
ADT pB = create(B | S = "70/30 40/60"); ADT pB = create(B | S = "70/30 40/60");
gttoc_(createCPTs); printCounts("Create CPTs");
tictoc_getNode(createCPTsNode, createCPTs);
elapsed = createCPTsNode->secs() + createCPTsNode->wall();
tictoc_reset_();
// printCounts("Create CPTs");
// Create joint // Create joint
resetCounts(); resetCounts();
gttic_(asiaFG);
ADT fg = pS; ADT fg = pS;
fg = apply(fg, pT, &mul); fg = apply(fg, pT, &mul);
fg = apply(fg, pL, &mul); fg = apply(fg, pL, &mul);
@ -312,14 +273,9 @@ TEST(ADT, factor_graph) {
#else #else
EXPECT_LONGS_EQUAL(188, (long)muls); EXPECT_LONGS_EQUAL(188, (long)muls);
#endif #endif
gttoc_(asiaFG);
tictoc_getNode(asiaFGNode, asiaFG);
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
tictoc_reset_();
printCounts("Asia FG"); printCounts("Asia FG");
resetCounts(); resetCounts();
gttic_(marg);
fg = fg.combine(X, &add_); fg = fg.combine(X, &add_);
dot(fg, "Marginalized-6X"); dot(fg, "Marginalized-6X");
fg = fg.combine(T, &add_); fg = fg.combine(T, &add_);
@ -335,83 +291,54 @@ TEST(ADT, factor_graph) {
#else #else
LONGS_EQUAL(62, adds); LONGS_EQUAL(62, adds);
#endif #endif
gttoc_(marg);
tictoc_getNode(margNode, marg);
elapsed = margNode->secs() + margNode->wall();
tictoc_reset_();
printCounts("marginalize"); printCounts("marginalize");
// BLESTX // BLESTX
// Eliminate X // Eliminate X
resetCounts(); resetCounts();
gttic_(elimX);
ADT fE = pX; ADT fE = pX;
dot(fE, "Eliminate-01-fEX"); dot(fE, "Eliminate-01-fEX");
fE = fE.combine(X, &add_); fE = fE.combine(X, &add_);
dot(fE, "Eliminate-02-fE"); dot(fE, "Eliminate-02-fE");
gttoc_(elimX);
tictoc_getNode(elimXNode, elimX);
elapsed = elimXNode->secs() + elimXNode->wall();
tictoc_reset_();
printCounts("Eliminate X"); printCounts("Eliminate X");
// Eliminate T // Eliminate T
resetCounts(); resetCounts();
gttic_(elimT);
ADT fLE = pT; ADT fLE = pT;
fLE = apply(fLE, pE, &mul); fLE = apply(fLE, pE, &mul);
dot(fLE, "Eliminate-03-fLET"); dot(fLE, "Eliminate-03-fLET");
fLE = fLE.combine(T, &add_); fLE = fLE.combine(T, &add_);
dot(fLE, "Eliminate-04-fLE"); dot(fLE, "Eliminate-04-fLE");
gttoc_(elimT);
tictoc_getNode(elimTNode, elimT);
elapsed = elimTNode->secs() + elimTNode->wall();
tictoc_reset_();
printCounts("Eliminate T"); printCounts("Eliminate T");
// Eliminate S // Eliminate S
resetCounts(); resetCounts();
gttic_(elimS);
ADT fBL = pS; ADT fBL = pS;
fBL = apply(fBL, pL, &mul); fBL = apply(fBL, pL, &mul);
fBL = apply(fBL, pB, &mul); fBL = apply(fBL, pB, &mul);
dot(fBL, "Eliminate-05-fBLS"); dot(fBL, "Eliminate-05-fBLS");
fBL = fBL.combine(S, &add_); fBL = fBL.combine(S, &add_);
dot(fBL, "Eliminate-06-fBL"); dot(fBL, "Eliminate-06-fBL");
gttoc_(elimS);
tictoc_getNode(elimSNode, elimS);
elapsed = elimSNode->secs() + elimSNode->wall();
tictoc_reset_();
printCounts("Eliminate S"); printCounts("Eliminate S");
// Eliminate E // Eliminate E
resetCounts(); resetCounts();
gttic_(elimE);
ADT fBL2 = fE; ADT fBL2 = fE;
fBL2 = apply(fBL2, fLE, &mul); fBL2 = apply(fBL2, fLE, &mul);
fBL2 = apply(fBL2, pD, &mul); fBL2 = apply(fBL2, pD, &mul);
dot(fBL2, "Eliminate-07-fBLE"); dot(fBL2, "Eliminate-07-fBLE");
fBL2 = fBL2.combine(E, &add_); fBL2 = fBL2.combine(E, &add_);
dot(fBL2, "Eliminate-08-fBL2"); dot(fBL2, "Eliminate-08-fBL2");
gttoc_(elimE);
tictoc_getNode(elimENode, elimE);
elapsed = elimENode->secs() + elimENode->wall();
tictoc_reset_();
printCounts("Eliminate E"); printCounts("Eliminate E");
// Eliminate L // Eliminate L
resetCounts(); resetCounts();
gttic_(elimL);
ADT fB = fBL; ADT fB = fBL;
fB = apply(fB, fBL2, &mul); fB = apply(fB, fBL2, &mul);
dot(fB, "Eliminate-09-fBL"); dot(fB, "Eliminate-09-fBL");
fB = fB.combine(L, &add_); fB = fB.combine(L, &add_);
dot(fB, "Eliminate-10-fB"); dot(fB, "Eliminate-10-fB");
gttoc_(elimL);
tictoc_getNode(elimLNode, elimL);
elapsed = elimLNode->secs() + elimLNode->wall();
tictoc_reset_();
printCounts("Eliminate L"); printCounts("Eliminate L");
} }