Get rid of unreliable timing

release/4.3a0
Frank Dellaert 2024-09-23 17:46:11 -07:00
parent 5aa5222edb
commit 4c7d3b5a50
1 changed files with 7 additions and 80 deletions

View File

@ -20,12 +20,9 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std;
@ -71,16 +68,14 @@ void dot(const T& f, const string& filename) {
// instrumented operators
/* ************************************************************************** */
size_t muls = 0, adds = 0;
double elapsed;
void resetCounts() {
muls = 0;
adds = 0;
}
void printCounts(const string& s) {
#ifndef DISABLE_TIMING
cout << s << ": " << std::setw(3) << muls << " muls, " <<
std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms."
<< endl;
cout << s << ": " << std::setw(3) << muls << " muls, " << std::setw(3) << adds
<< " adds" << endl;
#endif
resetCounts();
}
@ -131,7 +126,8 @@ ADT create(const Signature& signature) {
static size_t count = 0;
const DiscreteKey& key = signature.key();
std::stringstream ss;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-"
<< key.first;
string DOTfile = ss.str();
dot(p, DOTfile);
return p;
@ -143,8 +139,6 @@ TEST(ADT, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
resetCounts();
gttic_(asiaCPTs);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
@ -153,15 +147,9 @@ TEST(ADT, joint) {
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(asiaCPTs);
tictoc_getNode(asiaCPTsNode, asiaCPTs);
elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall();
tictoc_reset_();
printCounts("Asia CPTs");
// Create joint
resetCounts();
gttic_(asiaJoint);
ADT joint = pA;
dot(joint, "Asia-A");
joint = apply(joint, pS, &mul);
@ -183,10 +171,6 @@ TEST(ADT, joint) {
#else
EXPECT_LONGS_EQUAL(508, muls);
#endif
gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
tictoc_reset_();
printCounts("Asia joint");
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
@ -208,8 +192,6 @@ TEST(ADT, inference) {
DiscreteKey A(0, 2), D(1, 2), //
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
resetCounts();
gttic_(infCPTs);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
@ -218,15 +200,9 @@ TEST(ADT, inference) {
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(infCPTs);
tictoc_getNode(infCPTsNode, infCPTs);
elapsed = infCPTsNode->secs() + infCPTsNode->wall();
tictoc_reset_();
// printCounts("Inference CPTs");
// Create joint
// Create joint, note different ordering than above: different tree!
resetCounts();
gttic_(asiaProd);
ADT joint = pA;
dot(joint, "Joint-Product-A");
joint = apply(joint, pS, &mul);
@ -248,14 +224,9 @@ TEST(ADT, inference) {
#else
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
#endif
gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
tictoc_reset_();
printCounts("Asia product");
resetCounts();
gttic_(asiaSum);
ADT marginal = joint;
marginal = marginal.combine(X, &add_);
dot(marginal, "Joint-Sum-ADBLEST");
@ -270,10 +241,6 @@ TEST(ADT, inference) {
#else
EXPECT_LONGS_EQUAL(240, (long)adds);
#endif
gttoc_(asiaSum);
tictoc_getNode(asiaSumNode, asiaSum);
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
tictoc_reset_();
printCounts("Asia sum");
}
@ -282,7 +249,6 @@ TEST(ADT, factor_graph) {
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
resetCounts();
gttic_(createCPTs);
ADT pS = create(S % "50/50");
ADT pT = create(T % "95/5");
ADT pL = create(L | S = "99/1 90/10");
@ -290,15 +256,10 @@ TEST(ADT, factor_graph) {
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create(B | E = "1/8 7/9");
ADT pB = create(B | S = "70/30 40/60");
gttoc_(createCPTs);
tictoc_getNode(createCPTsNode, createCPTs);
elapsed = createCPTsNode->secs() + createCPTsNode->wall();
tictoc_reset_();
// printCounts("Create CPTs");
printCounts("Create CPTs");
// Create joint
resetCounts();
gttic_(asiaFG);
ADT fg = pS;
fg = apply(fg, pT, &mul);
fg = apply(fg, pL, &mul);
@ -312,14 +273,9 @@ TEST(ADT, factor_graph) {
#else
EXPECT_LONGS_EQUAL(188, (long)muls);
#endif
gttoc_(asiaFG);
tictoc_getNode(asiaFGNode, asiaFG);
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
tictoc_reset_();
printCounts("Asia FG");
resetCounts();
gttic_(marg);
fg = fg.combine(X, &add_);
dot(fg, "Marginalized-6X");
fg = fg.combine(T, &add_);
@ -335,83 +291,54 @@ TEST(ADT, factor_graph) {
#else
LONGS_EQUAL(62, adds);
#endif
gttoc_(marg);
tictoc_getNode(margNode, marg);
elapsed = margNode->secs() + margNode->wall();
tictoc_reset_();
printCounts("marginalize");
// BLESTX
// Eliminate X
resetCounts();
gttic_(elimX);
ADT fE = pX;
dot(fE, "Eliminate-01-fEX");
fE = fE.combine(X, &add_);
dot(fE, "Eliminate-02-fE");
gttoc_(elimX);
tictoc_getNode(elimXNode, elimX);
elapsed = elimXNode->secs() + elimXNode->wall();
tictoc_reset_();
printCounts("Eliminate X");
// Eliminate T
resetCounts();
gttic_(elimT);
ADT fLE = pT;
fLE = apply(fLE, pE, &mul);
dot(fLE, "Eliminate-03-fLET");
fLE = fLE.combine(T, &add_);
dot(fLE, "Eliminate-04-fLE");
gttoc_(elimT);
tictoc_getNode(elimTNode, elimT);
elapsed = elimTNode->secs() + elimTNode->wall();
tictoc_reset_();
printCounts("Eliminate T");
// Eliminate S
resetCounts();
gttic_(elimS);
ADT fBL = pS;
fBL = apply(fBL, pL, &mul);
fBL = apply(fBL, pB, &mul);
dot(fBL, "Eliminate-05-fBLS");
fBL = fBL.combine(S, &add_);
dot(fBL, "Eliminate-06-fBL");
gttoc_(elimS);
tictoc_getNode(elimSNode, elimS);
elapsed = elimSNode->secs() + elimSNode->wall();
tictoc_reset_();
printCounts("Eliminate S");
// Eliminate E
resetCounts();
gttic_(elimE);
ADT fBL2 = fE;
fBL2 = apply(fBL2, fLE, &mul);
fBL2 = apply(fBL2, pD, &mul);
dot(fBL2, "Eliminate-07-fBLE");
fBL2 = fBL2.combine(E, &add_);
dot(fBL2, "Eliminate-08-fBL2");
gttoc_(elimE);
tictoc_getNode(elimENode, elimE);
elapsed = elimENode->secs() + elimENode->wall();
tictoc_reset_();
printCounts("Eliminate E");
// Eliminate L
resetCounts();
gttic_(elimL);
ADT fB = fBL;
fB = apply(fB, fBL2, &mul);
dot(fB, "Eliminate-09-fBL");
fB = fB.combine(L, &add_);
dot(fB, "Eliminate-10-fB");
gttoc_(elimL);
tictoc_getNode(elimLNode, elimL);
elapsed = elimLNode->secs() + elimLNode->wall();
tictoc_reset_();
printCounts("Eliminate L");
}