Merge pull request #1844 from borglab/feature/timeHybrid

release/4.3a0
Varun Agrawal 2024-09-24 15:32:03 -04:00 committed by GitHub
commit e4ec8d3b9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 129 additions and 136 deletions

View File

@ -91,7 +91,7 @@ namespace gtsam {
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
std::string value = valueFormatter(constant_);
const std::string value = valueFormatter(constant_);
if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
@ -306,7 +306,8 @@ namespace gtsam {
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter,
bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
const std::string label = labelFormatter(label_);
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label
<< "\"]\n";
size_t B = branches_.size();
for (size_t i = 0; i < B; i++) {

View File

@ -147,14 +147,14 @@ namespace gtsam {
size_t i;
ADT result(*this);
for (i = 0; i < nrFrontals; i++) {
Key j = keys()[i];
Key j = keys_[i];
result = result.combine(j, cardinality(j), op);
}
// create new factor, note we start keys after nrFrontals
// Create new factor, note we start with keys after nrFrontals:
DiscreteKeys dkeys;
for (; i < keys().size(); i++) {
Key j = keys()[i];
for (; i < keys_.size(); i++) {
Key j = keys_[i];
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return std::make_shared<DecisionTreeFactor>(dkeys, result);
@ -179,24 +179,22 @@ namespace gtsam {
result = result.combine(j, cardinality(j), op);
}
// create new factor, note we collect keys that are not in frontalKeys
/*
Due to branch merging, the labels in `result` may be missing some keys
Create new factor, note we collect keys that are not in frontalKeys.
Due to branch merging, the labels in `result` may be missing some keys.
E.g. After branch merging, we may get a ADT like:
Leaf [2] 1.0204082
This is missing the key values used for branching.
Hence, code below traverses the original keys and omits those in
frontalKeys. We loop over cardinalities, which is O(n) even for a map, and
then "contains" is a binary search on a small vector.
*/
KeyVector difference, frontalKeys_(frontalKeys), keys_(keys());
// Get the difference of the frontalKeys and the factor keys using set_difference
std::sort(keys_.begin(), keys_.end());
std::sort(frontalKeys_.begin(), frontalKeys_.end());
std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(),
frontalKeys_.end(), back_inserter(difference));
DiscreteKeys dkeys;
for (Key key : difference) {
dkeys.push_back(DiscreteKey(key, cardinality(key)));
for (auto&& [key, cardinality] : cardinalities_) {
if (!frontalKeys.contains(key)) {
dkeys.push_back(DiscreteKey(key, cardinality));
}
}
return std::make_shared<DecisionTreeFactor>(dkeys, result);
}

View File

@ -20,12 +20,9 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std;
@ -71,16 +68,14 @@ void dot(const T& f, const string& filename) {
// instrumented operators
/* ************************************************************************** */
size_t muls = 0, adds = 0;
double elapsed;
void resetCounts() {
muls = 0;
adds = 0;
}
void printCounts(const string& s) {
#ifndef DISABLE_TIMING
cout << s << ": " << std::setw(3) << muls << " muls, " <<
std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms."
<< endl;
cout << s << ": " << std::setw(3) << muls << " muls, " << std::setw(3) << adds
<< " adds" << endl;
#endif
resetCounts();
}
@ -131,37 +126,35 @@ ADT create(const Signature& signature) {
static size_t count = 0;
const DiscreteKey& key = signature.key();
std::stringstream ss;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-"
<< key.first;
string DOTfile = ss.str();
dot(p, DOTfile);
return p;
}
/* ************************************************************************* */
namespace asiaCPTs {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pB = create(B | S = "70/30 40/60");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
} // namespace asiaCPTs
/* ************************************************************************* */
// test Asia Joint
TEST(ADT, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
resetCounts();
gttic_(asiaCPTs);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pB = create(B | S = "70/30 40/60");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(asiaCPTs);
tictoc_getNode(asiaCPTsNode, asiaCPTs);
elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall();
tictoc_reset_();
printCounts("Asia CPTs");
using namespace asiaCPTs;
// Create joint
resetCounts();
gttic_(asiaJoint);
ADT joint = pA;
dot(joint, "Asia-A");
joint = apply(joint, pS, &mul);
@ -183,11 +176,12 @@ TEST(ADT, joint) {
#else
EXPECT_LONGS_EQUAL(508, muls);
#endif
gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
tictoc_reset_();
printCounts("Asia joint");
}
/* ************************************************************************* */
TEST(ADT, combine) {
using namespace asiaCPTs;
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
ADT pASTL = pA;
@ -203,13 +197,11 @@ TEST(ADT, joint) {
}
/* ************************************************************************* */
// test Inference with joint
// test Inference with joint, created using different ordering
TEST(ADT, inference) {
DiscreteKey A(0, 2), D(1, 2), //
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
resetCounts();
gttic_(infCPTs);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
@ -218,15 +210,9 @@ TEST(ADT, inference) {
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(infCPTs);
tictoc_getNode(infCPTsNode, infCPTs);
elapsed = infCPTsNode->secs() + infCPTsNode->wall();
tictoc_reset_();
// printCounts("Inference CPTs");
// Create joint
// Create joint, note different ordering than above: different tree!
resetCounts();
gttic_(asiaProd);
ADT joint = pA;
dot(joint, "Joint-Product-A");
joint = apply(joint, pS, &mul);
@ -248,14 +234,9 @@ TEST(ADT, inference) {
#else
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
#endif
gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
tictoc_reset_();
printCounts("Asia product");
resetCounts();
gttic_(asiaSum);
ADT marginal = joint;
marginal = marginal.combine(X, &add_);
dot(marginal, "Joint-Sum-ADBLEST");
@ -270,10 +251,6 @@ TEST(ADT, inference) {
#else
EXPECT_LONGS_EQUAL(240, (long)adds);
#endif
gttoc_(asiaSum);
tictoc_getNode(asiaSumNode, asiaSum);
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
tictoc_reset_();
printCounts("Asia sum");
}
@ -281,8 +258,6 @@ TEST(ADT, inference) {
TEST(ADT, factor_graph) {
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
resetCounts();
gttic_(createCPTs);
ADT pS = create(S % "50/50");
ADT pT = create(T % "95/5");
ADT pL = create(L | S = "99/1 90/10");
@ -290,15 +265,9 @@ TEST(ADT, factor_graph) {
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create(B | E = "1/8 7/9");
ADT pB = create(B | S = "70/30 40/60");
gttoc_(createCPTs);
tictoc_getNode(createCPTsNode, createCPTs);
elapsed = createCPTsNode->secs() + createCPTsNode->wall();
tictoc_reset_();
// printCounts("Create CPTs");
// Create joint
resetCounts();
gttic_(asiaFG);
ADT fg = pS;
fg = apply(fg, pT, &mul);
fg = apply(fg, pL, &mul);
@ -312,14 +281,9 @@ TEST(ADT, factor_graph) {
#else
EXPECT_LONGS_EQUAL(188, (long)muls);
#endif
gttoc_(asiaFG);
tictoc_getNode(asiaFGNode, asiaFG);
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
tictoc_reset_();
printCounts("Asia FG");
resetCounts();
gttic_(marg);
fg = fg.combine(X, &add_);
dot(fg, "Marginalized-6X");
fg = fg.combine(T, &add_);
@ -335,83 +299,54 @@ TEST(ADT, factor_graph) {
#else
LONGS_EQUAL(62, adds);
#endif
gttoc_(marg);
tictoc_getNode(margNode, marg);
elapsed = margNode->secs() + margNode->wall();
tictoc_reset_();
printCounts("marginalize");
// BLESTX
// Eliminate X
resetCounts();
gttic_(elimX);
ADT fE = pX;
dot(fE, "Eliminate-01-fEX");
fE = fE.combine(X, &add_);
dot(fE, "Eliminate-02-fE");
gttoc_(elimX);
tictoc_getNode(elimXNode, elimX);
elapsed = elimXNode->secs() + elimXNode->wall();
tictoc_reset_();
printCounts("Eliminate X");
// Eliminate T
resetCounts();
gttic_(elimT);
ADT fLE = pT;
fLE = apply(fLE, pE, &mul);
dot(fLE, "Eliminate-03-fLET");
fLE = fLE.combine(T, &add_);
dot(fLE, "Eliminate-04-fLE");
gttoc_(elimT);
tictoc_getNode(elimTNode, elimT);
elapsed = elimTNode->secs() + elimTNode->wall();
tictoc_reset_();
printCounts("Eliminate T");
// Eliminate S
resetCounts();
gttic_(elimS);
ADT fBL = pS;
fBL = apply(fBL, pL, &mul);
fBL = apply(fBL, pB, &mul);
dot(fBL, "Eliminate-05-fBLS");
fBL = fBL.combine(S, &add_);
dot(fBL, "Eliminate-06-fBL");
gttoc_(elimS);
tictoc_getNode(elimSNode, elimS);
elapsed = elimSNode->secs() + elimSNode->wall();
tictoc_reset_();
printCounts("Eliminate S");
// Eliminate E
resetCounts();
gttic_(elimE);
ADT fBL2 = fE;
fBL2 = apply(fBL2, fLE, &mul);
fBL2 = apply(fBL2, pD, &mul);
dot(fBL2, "Eliminate-07-fBLE");
fBL2 = fBL2.combine(E, &add_);
dot(fBL2, "Eliminate-08-fBL2");
gttoc_(elimE);
tictoc_getNode(elimENode, elimE);
elapsed = elimENode->secs() + elimENode->wall();
tictoc_reset_();
printCounts("Eliminate E");
// Eliminate L
resetCounts();
gttic_(elimL);
ADT fB = fBL;
fB = apply(fB, fBL2, &mul);
dot(fB, "Eliminate-09-fBL");
fB = fB.combine(L, &add_);
dot(fB, "Eliminate-10-fB");
gttoc_(elimL);
tictoc_getNode(elimLNode, elimL);
elapsed = elimLNode->secs() + elimLNode->wall();
tictoc_reset_();
printCounts("Eliminate L");
}

View File

@ -22,7 +22,10 @@
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Ordering.h>
using namespace std;
using namespace gtsam;
@ -33,25 +36,24 @@ TEST(DecisionTreeFactor, ConstructorsMatch) {
DiscreteKey X(0, 2), Y(1, 3);
// Create with vector and with string
const std::vector<double> table {2, 5, 3, 6, 4, 7};
const std::vector<double> table{2, 5, 3, 6, 4, 7};
DecisionTreeFactor f1({X, Y}, table);
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
EXPECT(assert_equal(f1, f2));
}
/* ************************************************************************* */
TEST( DecisionTreeFactor, constructors)
{
TEST(DecisionTreeFactor, constructors) {
// Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2);
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
// Create factors
DecisionTreeFactor f1(X, {2, 8});
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
EXPECT_LONGS_EQUAL(1,f1.size());
EXPECT_LONGS_EQUAL(2,f2.size());
EXPECT_LONGS_EQUAL(3,f3.size());
EXPECT_LONGS_EQUAL(1, f1.size());
EXPECT_LONGS_EQUAL(2, f2.size());
EXPECT_LONGS_EQUAL(3, f3.size());
DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
@ -70,7 +72,7 @@ TEST( DecisionTreeFactor, constructors)
/* ************************************************************************* */
TEST(DecisionTreeFactor, Error) {
// Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2);
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
// Create factors
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
@ -104,9 +106,8 @@ TEST(DecisionTreeFactor, multiplication) {
}
/* ************************************************************************* */
TEST( DecisionTreeFactor, sum_max)
{
DiscreteKey v0(0,3), v1(1,2);
TEST(DecisionTreeFactor, sum_max) {
DiscreteKey v0(0, 3), v1(1, 2);
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12");
@ -165,22 +166,85 @@ TEST(DecisionTreeFactor, Prune) {
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
DecisionTreeFactor expected3(
D & C & B & A,
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.999952870000 1.0 1.0 1.0 1.0");
DecisionTreeFactor expected3(D & C & B & A,
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.999952870000 1.0 1.0 1.0 1.0");
maxNrAssignments = 5;
auto pruned3 = factor.prune(maxNrAssignments);
EXPECT(assert_equal(expected3, pruned3));
}
/* ************************************************************************** */
// Asia Bayes Network
/* ************************************************************************** */
#define DISABLE_DOT
void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#ifndef DISABLE_DOT
std::vector<std::string> names = {"A", "S", "T", "L", "B", "E", "X", "D"};
auto formatter = [&](Key key) { return names[key]; };
f.dot(filename, formatter, true);
#endif
}
/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}
/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
gttic_(asiaCPTs);
DecisionTreeFactor pA = create(A % "99/1");
DecisionTreeFactor pS = create(S % "50/50");
DecisionTreeFactor pT = create(T | A = "99/1 95/5");
DecisionTreeFactor pL = create(L | S = "99/1 90/10");
DecisionTreeFactor pB = create(B | S = "70/30 40/60");
DecisionTreeFactor pE = create((E | T, L) = "F T T T");
DecisionTreeFactor pX = create(X | E = "95/5 2/98");
DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
// Create joint
gttic_(asiaJoint);
DecisionTreeFactor joint = pA;
maybeSaveDotFile(joint, "Asia-A");
joint = joint * pS;
maybeSaveDotFile(joint, "Asia-AS");
joint = joint * pT;
maybeSaveDotFile(joint, "Asia-AST");
joint = joint * pL;
maybeSaveDotFile(joint, "Asia-ASTL");
joint = joint * pB;
maybeSaveDotFile(joint, "Asia-ASTLB");
joint = joint * pE;
maybeSaveDotFile(joint, "Asia-ASTLBE");
joint = joint * pX;
maybeSaveDotFile(joint, "Asia-ASTLBEX");
joint = joint * pD;
maybeSaveDotFile(joint, "Asia-ASTLBEXD");
// Check that discrete keys are as expected
EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D}));
// Check that summing out variables maintains the keys even if merged, as is
// the case with S.
auto noAB = joint.sum(Ordering{A.first, B.first});
EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D}));
}
/* ************************************************************************* */
TEST(DecisionTreeFactor, DotWithNames) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
for (bool showZero:{true, false}) {
for (bool showZero : {true, false}) {
string actual = f.dot(formatter, showZero);
// pretty weak test, as ids are pointers and not stable across platforms.
string expected = "digraph G {";

View File

@ -22,7 +22,7 @@ namespace gtsam {
/* *******************************************************************************/
static void checkKeys(const KeyVector& continuousKeys,
std::vector<NonlinearFactorValuePair>& pairs) {
const std::vector<NonlinearFactorValuePair>& pairs) {
KeySet factor_keys_set;
for (const auto& pair : pairs) {
auto f = pair.first;
@ -55,14 +55,9 @@ HybridNonlinearFactor::HybridNonlinearFactor(
/* *******************************************************************************/
HybridNonlinearFactor::HybridNonlinearFactor(
const KeyVector& continuousKeys, const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& factors)
const std::vector<NonlinearFactorValuePair>& pairs)
: Base(continuousKeys, {discreteKey}) {
std::vector<NonlinearFactorValuePair> pairs;
KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end());
KeySet factor_keys_set;
for (auto&& [f, val] : factors) {
pairs.emplace_back(f, val);
}
checkKeys(continuousKeys, pairs);
factors_ = FactorValuePairs({discreteKey}, pairs);
}

View File

@ -106,11 +106,11 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
*
* @param continuousKeys Vector of keys for continuous factors.
* @param discreteKey The discrete key for the "mode", indexing components.
* @param factors Vector of gaussian factor-scalar pairs, one per mode.
* @param pairs Vector of gaussian factor-scalar pairs, one per mode.
*/
HybridNonlinearFactor(const KeyVector& continuousKeys,
const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& factors);
const std::vector<NonlinearFactorValuePair>& pairs);
/**
* @brief Construct a new HybridNonlinearFactor on a several discrete keys M,