Merge pull request #1844 from borglab/feature/timeHybrid
						commit
						e4ec8d3b9c
					
				|  | @ -91,7 +91,7 @@ namespace gtsam { | |||
|     void dot(std::ostream& os, const LabelFormatter& labelFormatter, | ||||
|              const ValueFormatter& valueFormatter, | ||||
|              bool showZero) const override { | ||||
|       std::string value = valueFormatter(constant_); | ||||
|       const std::string value = valueFormatter(constant_); | ||||
|       if (showZero || value.compare("0")) | ||||
|         os << "\"" << this->id() << "\" [label=\"" << value | ||||
|            << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; | ||||
|  | @ -306,7 +306,8 @@ namespace gtsam { | |||
|     void dot(std::ostream& os, const LabelFormatter& labelFormatter, | ||||
|              const ValueFormatter& valueFormatter, | ||||
|              bool showZero) const override { | ||||
|       os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ | ||||
|       const std::string label = labelFormatter(label_); | ||||
|       os << "\"" << this->id() << "\" [shape=circle, label=\"" << label | ||||
|           << "\"]\n"; | ||||
|       size_t B = branches_.size(); | ||||
|       for (size_t i = 0; i < B; i++) { | ||||
|  |  | |||
|  | @ -147,14 +147,14 @@ namespace gtsam { | |||
|     size_t i; | ||||
|     ADT result(*this); | ||||
|     for (i = 0; i < nrFrontals; i++) { | ||||
|       Key j = keys()[i]; | ||||
|       Key j = keys_[i]; | ||||
|       result = result.combine(j, cardinality(j), op); | ||||
|     } | ||||
| 
 | ||||
|     // create new factor, note we start keys after nrFrontals
 | ||||
|     // Create new factor, note we start with keys after nrFrontals:
 | ||||
|     DiscreteKeys dkeys; | ||||
|     for (; i < keys().size(); i++) { | ||||
|       Key j = keys()[i]; | ||||
|     for (; i < keys_.size(); i++) { | ||||
|       Key j = keys_[i]; | ||||
|       dkeys.push_back(DiscreteKey(j, cardinality(j))); | ||||
|     } | ||||
|     return std::make_shared<DecisionTreeFactor>(dkeys, result); | ||||
|  | @ -179,24 +179,22 @@ namespace gtsam { | |||
|       result = result.combine(j, cardinality(j), op); | ||||
|     } | ||||
| 
 | ||||
|     // create new factor, note we collect keys that are not in frontalKeys
 | ||||
|     /*
 | ||||
|     Due to branch merging, the labels in `result` may be missing some keys | ||||
|     Create new factor, note we collect keys that are not in frontalKeys. | ||||
|      | ||||
|     Due to branch merging, the labels in `result` may be missing some keys. | ||||
|     E.g. After branch merging, we may get a ADT like: | ||||
|       Leaf [2] 1.0204082 | ||||
| 
 | ||||
|     This is missing the key values used for branching. | ||||
|     Hence, code below traverses the original keys and omits those in | ||||
|     frontalKeys. We loop over cardinalities, which is O(n) even for a map, and | ||||
|     then "contains" is a binary search on a small vector. | ||||
|     */ | ||||
|     KeyVector difference, frontalKeys_(frontalKeys), keys_(keys()); | ||||
|     // Get the difference of the frontalKeys and the factor keys using set_difference
 | ||||
|     std::sort(keys_.begin(), keys_.end()); | ||||
|     std::sort(frontalKeys_.begin(), frontalKeys_.end()); | ||||
|     std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(), | ||||
|                         frontalKeys_.end(), back_inserter(difference)); | ||||
| 
 | ||||
|     DiscreteKeys dkeys; | ||||
|     for (Key key : difference) { | ||||
|       dkeys.push_back(DiscreteKey(key, cardinality(key))); | ||||
|     for (auto&& [key, cardinality] : cardinalities_) { | ||||
|       if (!frontalKeys.contains(key)) {  | ||||
|         dkeys.push_back(DiscreteKey(key, cardinality)); | ||||
|       } | ||||
|     } | ||||
|     return std::make_shared<DecisionTreeFactor>(dkeys, result); | ||||
|   } | ||||
|  |  | |||
|  | @ -20,12 +20,9 @@ | |||
| #include <gtsam/discrete/DiscreteKey.h>  // make sure we have traits | ||||
| #include <gtsam/discrete/DiscreteValues.h> | ||||
| // headers first to make sure no missing headers
 | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
| #include <gtsam/discrete/AlgebraicDecisionTree.h> | ||||
| #include <gtsam/discrete/DecisionTree-inl.h>  // for convert only | ||||
| #define DISABLE_TIMING | ||||
| 
 | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
| #include <gtsam/base/timing.h> | ||||
| #include <gtsam/discrete/Signature.h> | ||||
| 
 | ||||
| using namespace std; | ||||
|  | @ -71,16 +68,14 @@ void dot(const T& f, const string& filename) { | |||
| // instrumented operators
 | ||||
| /* ************************************************************************** */ | ||||
| size_t muls = 0, adds = 0; | ||||
| double elapsed; | ||||
| void resetCounts() { | ||||
|   muls = 0; | ||||
|   adds = 0; | ||||
| } | ||||
| void printCounts(const string& s) { | ||||
| #ifndef DISABLE_TIMING | ||||
| cout << s << ": " << std::setw(3) << muls << " muls, " <<  | ||||
|   std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms." | ||||
|      << endl; | ||||
|   cout << s << ": " << std::setw(3) << muls << " muls, " << std::setw(3) << adds | ||||
|        << " adds" << endl; | ||||
| #endif | ||||
|   resetCounts(); | ||||
| } | ||||
|  | @ -131,37 +126,35 @@ ADT create(const Signature& signature) { | |||
|   static size_t count = 0; | ||||
|   const DiscreteKey& key = signature.key(); | ||||
|   std::stringstream ss; | ||||
|   ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first; | ||||
|   ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" | ||||
|      << key.first; | ||||
|   string DOTfile = ss.str(); | ||||
|   dot(p, DOTfile); | ||||
|   return p; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| namespace asiaCPTs { | ||||
| DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), | ||||
|     D(7, 2); | ||||
| 
 | ||||
| ADT pA = create(A % "99/1"); | ||||
| ADT pS = create(S % "50/50"); | ||||
| ADT pT = create(T | A = "99/1 95/5"); | ||||
| ADT pL = create(L | S = "99/1 90/10"); | ||||
| ADT pB = create(B | S = "70/30 40/60"); | ||||
| ADT pE = create((E | T, L) = "F T T T"); | ||||
| ADT pX = create(X | E = "95/5 2/98"); | ||||
| ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); | ||||
| }  // namespace asiaCPTs
 | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // test Asia Joint
 | ||||
| TEST(ADT, joint) { | ||||
|   DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), | ||||
|       D(7, 2); | ||||
| 
 | ||||
|   resetCounts(); | ||||
|   gttic_(asiaCPTs); | ||||
|   ADT pA = create(A % "99/1"); | ||||
|   ADT pS = create(S % "50/50"); | ||||
|   ADT pT = create(T | A = "99/1 95/5"); | ||||
|   ADT pL = create(L | S = "99/1 90/10"); | ||||
|   ADT pB = create(B | S = "70/30 40/60"); | ||||
|   ADT pE = create((E | T, L) = "F T T T"); | ||||
|   ADT pX = create(X | E = "95/5 2/98"); | ||||
|   ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); | ||||
|   gttoc_(asiaCPTs); | ||||
|   tictoc_getNode(asiaCPTsNode, asiaCPTs); | ||||
|   elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Asia CPTs"); | ||||
|   using namespace asiaCPTs; | ||||
| 
 | ||||
|   // Create joint
 | ||||
|   resetCounts(); | ||||
|   gttic_(asiaJoint); | ||||
|   ADT joint = pA; | ||||
|   dot(joint, "Asia-A"); | ||||
|   joint = apply(joint, pS, &mul); | ||||
|  | @ -183,11 +176,12 @@ TEST(ADT, joint) { | |||
| #else | ||||
|   EXPECT_LONGS_EQUAL(508, muls); | ||||
| #endif | ||||
|   gttoc_(asiaJoint); | ||||
|   tictoc_getNode(asiaJointNode, asiaJoint); | ||||
|   elapsed = asiaJointNode->secs() + asiaJointNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Asia joint"); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(ADT, combine) { | ||||
|   using namespace asiaCPTs; | ||||
| 
 | ||||
|   // Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
 | ||||
|   ADT pASTL = pA; | ||||
|  | @ -203,13 +197,11 @@ TEST(ADT, joint) { | |||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // test Inference with joint
 | ||||
| // test Inference with joint, created using different ordering
 | ||||
| TEST(ADT, inference) { | ||||
|   DiscreteKey A(0, 2), D(1, 2),  //
 | ||||
|       B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2); | ||||
| 
 | ||||
|   resetCounts(); | ||||
|   gttic_(infCPTs); | ||||
|   ADT pA = create(A % "99/1"); | ||||
|   ADT pS = create(S % "50/50"); | ||||
|   ADT pT = create(T | A = "99/1 95/5"); | ||||
|  | @ -218,15 +210,9 @@ TEST(ADT, inference) { | |||
|   ADT pE = create((E | T, L) = "F T T T"); | ||||
|   ADT pX = create(X | E = "95/5 2/98"); | ||||
|   ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); | ||||
|   gttoc_(infCPTs); | ||||
|   tictoc_getNode(infCPTsNode, infCPTs); | ||||
|   elapsed = infCPTsNode->secs() + infCPTsNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   //  printCounts("Inference CPTs");
 | ||||
| 
 | ||||
|   // Create joint
 | ||||
|   // Create joint, note different ordering than above: different tree!
 | ||||
|   resetCounts(); | ||||
|   gttic_(asiaProd); | ||||
|   ADT joint = pA; | ||||
|   dot(joint, "Joint-Product-A"); | ||||
|   joint = apply(joint, pS, &mul); | ||||
|  | @ -248,14 +234,9 @@ TEST(ADT, inference) { | |||
| #else | ||||
|   EXPECT_LONGS_EQUAL(508, (long)muls);  // different ordering
 | ||||
| #endif | ||||
|   gttoc_(asiaProd); | ||||
|   tictoc_getNode(asiaProdNode, asiaProd); | ||||
|   elapsed = asiaProdNode->secs() + asiaProdNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Asia product"); | ||||
| 
 | ||||
|   resetCounts(); | ||||
|   gttic_(asiaSum); | ||||
|   ADT marginal = joint; | ||||
|   marginal = marginal.combine(X, &add_); | ||||
|   dot(marginal, "Joint-Sum-ADBLEST"); | ||||
|  | @ -270,10 +251,6 @@ TEST(ADT, inference) { | |||
| #else | ||||
|   EXPECT_LONGS_EQUAL(240, (long)adds); | ||||
| #endif | ||||
|   gttoc_(asiaSum); | ||||
|   tictoc_getNode(asiaSumNode, asiaSum); | ||||
|   elapsed = asiaSumNode->secs() + asiaSumNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Asia sum"); | ||||
| } | ||||
| 
 | ||||
|  | @ -281,8 +258,6 @@ TEST(ADT, inference) { | |||
| TEST(ADT, factor_graph) { | ||||
|   DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2); | ||||
| 
 | ||||
|   resetCounts(); | ||||
|   gttic_(createCPTs); | ||||
|   ADT pS = create(S % "50/50"); | ||||
|   ADT pT = create(T % "95/5"); | ||||
|   ADT pL = create(L | S = "99/1 90/10"); | ||||
|  | @ -290,15 +265,9 @@ TEST(ADT, factor_graph) { | |||
|   ADT pX = create(X | E = "95/5 2/98"); | ||||
|   ADT pD = create(B | E = "1/8 7/9"); | ||||
|   ADT pB = create(B | S = "70/30 40/60"); | ||||
|   gttoc_(createCPTs); | ||||
|   tictoc_getNode(createCPTsNode, createCPTs); | ||||
|   elapsed = createCPTsNode->secs() + createCPTsNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   //  printCounts("Create CPTs");
 | ||||
| 
 | ||||
|   // Create joint
 | ||||
|   resetCounts(); | ||||
|   gttic_(asiaFG); | ||||
|   ADT fg = pS; | ||||
|   fg = apply(fg, pT, &mul); | ||||
|   fg = apply(fg, pL, &mul); | ||||
|  | @ -312,14 +281,9 @@ TEST(ADT, factor_graph) { | |||
| #else | ||||
|   EXPECT_LONGS_EQUAL(188, (long)muls); | ||||
| #endif | ||||
|   gttoc_(asiaFG); | ||||
|   tictoc_getNode(asiaFGNode, asiaFG); | ||||
|   elapsed = asiaFGNode->secs() + asiaFGNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Asia FG"); | ||||
| 
 | ||||
|   resetCounts(); | ||||
|   gttic_(marg); | ||||
|   fg = fg.combine(X, &add_); | ||||
|   dot(fg, "Marginalized-6X"); | ||||
|   fg = fg.combine(T, &add_); | ||||
|  | @ -335,83 +299,54 @@ TEST(ADT, factor_graph) { | |||
| #else | ||||
|   LONGS_EQUAL(62, adds); | ||||
| #endif | ||||
|   gttoc_(marg); | ||||
|   tictoc_getNode(margNode, marg); | ||||
|   elapsed = margNode->secs() + margNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("marginalize"); | ||||
| 
 | ||||
|   // BLESTX
 | ||||
| 
 | ||||
|   // Eliminate X
 | ||||
|   resetCounts(); | ||||
|   gttic_(elimX); | ||||
|   ADT fE = pX; | ||||
|   dot(fE, "Eliminate-01-fEX"); | ||||
|   fE = fE.combine(X, &add_); | ||||
|   dot(fE, "Eliminate-02-fE"); | ||||
|   gttoc_(elimX); | ||||
|   tictoc_getNode(elimXNode, elimX); | ||||
|   elapsed = elimXNode->secs() + elimXNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Eliminate X"); | ||||
| 
 | ||||
|   // Eliminate T
 | ||||
|   resetCounts(); | ||||
|   gttic_(elimT); | ||||
|   ADT fLE = pT; | ||||
|   fLE = apply(fLE, pE, &mul); | ||||
|   dot(fLE, "Eliminate-03-fLET"); | ||||
|   fLE = fLE.combine(T, &add_); | ||||
|   dot(fLE, "Eliminate-04-fLE"); | ||||
|   gttoc_(elimT); | ||||
|   tictoc_getNode(elimTNode, elimT); | ||||
|   elapsed = elimTNode->secs() + elimTNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Eliminate T"); | ||||
| 
 | ||||
|   // Eliminate S
 | ||||
|   resetCounts(); | ||||
|   gttic_(elimS); | ||||
|   ADT fBL = pS; | ||||
|   fBL = apply(fBL, pL, &mul); | ||||
|   fBL = apply(fBL, pB, &mul); | ||||
|   dot(fBL, "Eliminate-05-fBLS"); | ||||
|   fBL = fBL.combine(S, &add_); | ||||
|   dot(fBL, "Eliminate-06-fBL"); | ||||
|   gttoc_(elimS); | ||||
|   tictoc_getNode(elimSNode, elimS); | ||||
|   elapsed = elimSNode->secs() + elimSNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Eliminate S"); | ||||
| 
 | ||||
|   // Eliminate E
 | ||||
|   resetCounts(); | ||||
|   gttic_(elimE); | ||||
|   ADT fBL2 = fE; | ||||
|   fBL2 = apply(fBL2, fLE, &mul); | ||||
|   fBL2 = apply(fBL2, pD, &mul); | ||||
|   dot(fBL2, "Eliminate-07-fBLE"); | ||||
|   fBL2 = fBL2.combine(E, &add_); | ||||
|   dot(fBL2, "Eliminate-08-fBL2"); | ||||
|   gttoc_(elimE); | ||||
|   tictoc_getNode(elimENode, elimE); | ||||
|   elapsed = elimENode->secs() + elimENode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Eliminate E"); | ||||
| 
 | ||||
|   // Eliminate L
 | ||||
|   resetCounts(); | ||||
|   gttic_(elimL); | ||||
|   ADT fB = fBL; | ||||
|   fB = apply(fB, fBL2, &mul); | ||||
|   dot(fB, "Eliminate-09-fBL"); | ||||
|   fB = fB.combine(L, &add_); | ||||
|   dot(fB, "Eliminate-10-fB"); | ||||
|   gttoc_(elimL); | ||||
|   tictoc_getNode(elimLNode, elimL); | ||||
|   elapsed = elimLNode->secs() + elimLNode->wall(); | ||||
|   tictoc_reset_(); | ||||
|   printCounts("Eliminate L"); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -22,7 +22,10 @@ | |||
| #include <gtsam/base/serializationTestHelpers.h> | ||||
| #include <gtsam/discrete/DecisionTreeFactor.h> | ||||
| #include <gtsam/discrete/DiscreteDistribution.h> | ||||
| #include <gtsam/discrete/DiscreteFactor.h> | ||||
| #include <gtsam/discrete/Signature.h> | ||||
| #include <gtsam/inference/Key.h> | ||||
| #include <gtsam/inference/Ordering.h> | ||||
| 
 | ||||
| using namespace std; | ||||
| using namespace gtsam; | ||||
|  | @ -33,25 +36,24 @@ TEST(DecisionTreeFactor, ConstructorsMatch) { | |||
|   DiscreteKey X(0, 2), Y(1, 3); | ||||
| 
 | ||||
|   // Create with vector and with string
 | ||||
|   const std::vector<double> table {2, 5, 3, 6, 4, 7}; | ||||
|   const std::vector<double> table{2, 5, 3, 6, 4, 7}; | ||||
|   DecisionTreeFactor f1({X, Y}, table); | ||||
|   DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7"); | ||||
|   EXPECT(assert_equal(f1, f2)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST( DecisionTreeFactor, constructors) | ||||
| { | ||||
| TEST(DecisionTreeFactor, constructors) { | ||||
|   // Declare a bunch of keys
 | ||||
|   DiscreteKey X(0,2), Y(1,3), Z(2,2); | ||||
|   DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); | ||||
| 
 | ||||
|   // Create factors
 | ||||
|   DecisionTreeFactor f1(X, {2, 8}); | ||||
|   DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7"); | ||||
|   DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); | ||||
|   EXPECT_LONGS_EQUAL(1,f1.size()); | ||||
|   EXPECT_LONGS_EQUAL(2,f2.size()); | ||||
|   EXPECT_LONGS_EQUAL(3,f3.size()); | ||||
|   EXPECT_LONGS_EQUAL(1, f1.size()); | ||||
|   EXPECT_LONGS_EQUAL(2, f2.size()); | ||||
|   EXPECT_LONGS_EQUAL(3, f3.size()); | ||||
| 
 | ||||
|   DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}}; | ||||
|   EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9); | ||||
|  | @ -70,7 +72,7 @@ TEST( DecisionTreeFactor, constructors) | |||
| /* ************************************************************************* */ | ||||
| TEST(DecisionTreeFactor, Error) { | ||||
|   // Declare a bunch of keys
 | ||||
|   DiscreteKey X(0,2), Y(1,3), Z(2,2); | ||||
|   DiscreteKey X(0, 2), Y(1, 3), Z(2, 2); | ||||
| 
 | ||||
|   // Create factors
 | ||||
|   DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); | ||||
|  | @ -104,9 +106,8 @@ TEST(DecisionTreeFactor, multiplication) { | |||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST( DecisionTreeFactor, sum_max) | ||||
| { | ||||
|   DiscreteKey v0(0,3), v1(1,2); | ||||
| TEST(DecisionTreeFactor, sum_max) { | ||||
|   DiscreteKey v0(0, 3), v1(1, 2); | ||||
|   DecisionTreeFactor f1(v0 & v1, "1 2  3 4  5 6"); | ||||
| 
 | ||||
|   DecisionTreeFactor expected(v1, "9 12"); | ||||
|  | @ -165,22 +166,85 @@ TEST(DecisionTreeFactor, Prune) { | |||
|       "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " | ||||
|       "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); | ||||
| 
 | ||||
|   DecisionTreeFactor expected3( | ||||
|       D & C & B & A, | ||||
|       "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " | ||||
|       "0.999952870000 1.0 1.0 1.0 1.0"); | ||||
|   DecisionTreeFactor expected3(D & C & B & A, | ||||
|                                "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " | ||||
|                                "0.999952870000 1.0 1.0 1.0 1.0"); | ||||
|   maxNrAssignments = 5; | ||||
|   auto pruned3 = factor.prune(maxNrAssignments); | ||||
|   EXPECT(assert_equal(expected3, pruned3)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Asia Bayes Network
 | ||||
| /* ************************************************************************** */ | ||||
| 
 | ||||
| #define DISABLE_DOT | ||||
| 
 | ||||
| void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) { | ||||
| #ifndef DISABLE_DOT | ||||
|   std::vector<std::string> names = {"A", "S", "T", "L", "B", "E", "X", "D"}; | ||||
|   auto formatter = [&](Key key) { return names[key]; }; | ||||
|   f.dot(filename, formatter, true); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| /** Convert Signature into CPT */ | ||||
| DecisionTreeFactor create(const Signature& signature) { | ||||
|   DecisionTreeFactor p(signature.discreteKeys(), signature.cpt()); | ||||
|   return p; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // test Asia Joint
 | ||||
| TEST(DecisionTreeFactor, joint) { | ||||
|   DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), | ||||
|       D(7, 2); | ||||
| 
 | ||||
|   gttic_(asiaCPTs); | ||||
|   DecisionTreeFactor pA = create(A % "99/1"); | ||||
|   DecisionTreeFactor pS = create(S % "50/50"); | ||||
|   DecisionTreeFactor pT = create(T | A = "99/1 95/5"); | ||||
|   DecisionTreeFactor pL = create(L | S = "99/1 90/10"); | ||||
|   DecisionTreeFactor pB = create(B | S = "70/30 40/60"); | ||||
|   DecisionTreeFactor pE = create((E | T, L) = "F T T T"); | ||||
|   DecisionTreeFactor pX = create(X | E = "95/5 2/98"); | ||||
|   DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); | ||||
| 
 | ||||
|   // Create joint
 | ||||
|   gttic_(asiaJoint); | ||||
|   DecisionTreeFactor joint = pA; | ||||
|   maybeSaveDotFile(joint, "Asia-A"); | ||||
|   joint = joint * pS; | ||||
|   maybeSaveDotFile(joint, "Asia-AS"); | ||||
|   joint = joint * pT; | ||||
|   maybeSaveDotFile(joint, "Asia-AST"); | ||||
|   joint = joint * pL; | ||||
|   maybeSaveDotFile(joint, "Asia-ASTL"); | ||||
|   joint = joint * pB; | ||||
|   maybeSaveDotFile(joint, "Asia-ASTLB"); | ||||
|   joint = joint * pE; | ||||
|   maybeSaveDotFile(joint, "Asia-ASTLBE"); | ||||
|   joint = joint * pX; | ||||
|   maybeSaveDotFile(joint, "Asia-ASTLBEX"); | ||||
|   joint = joint * pD; | ||||
|   maybeSaveDotFile(joint, "Asia-ASTLBEXD"); | ||||
| 
 | ||||
|   // Check that discrete keys are as expected
 | ||||
|   EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D})); | ||||
| 
 | ||||
|   // Check that summing out variables maintains the keys even if merged, as is
 | ||||
|   // the case with S.
 | ||||
|   auto noAB = joint.sum(Ordering{A.first, B.first}); | ||||
|   EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D})); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| TEST(DecisionTreeFactor, DotWithNames) { | ||||
|   DiscreteKey A(12, 3), B(5, 2); | ||||
|   DecisionTreeFactor f(A & B, "1 2  3 4  5 6"); | ||||
|   auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; | ||||
| 
 | ||||
|   for (bool showZero:{true, false}) { | ||||
|   for (bool showZero : {true, false}) { | ||||
|     string actual = f.dot(formatter, showZero); | ||||
|     // pretty weak test, as ids are pointers and not stable across platforms.
 | ||||
|     string expected = "digraph G {"; | ||||
|  |  | |||
|  | @ -22,7 +22,7 @@ namespace gtsam { | |||
| 
 | ||||
| /* *******************************************************************************/ | ||||
| static void checkKeys(const KeyVector& continuousKeys, | ||||
|                       std::vector<NonlinearFactorValuePair>& pairs) { | ||||
|                       const std::vector<NonlinearFactorValuePair>& pairs) { | ||||
|   KeySet factor_keys_set; | ||||
|   for (const auto& pair : pairs) { | ||||
|     auto f = pair.first; | ||||
|  | @ -55,14 +55,9 @@ HybridNonlinearFactor::HybridNonlinearFactor( | |||
| /* *******************************************************************************/ | ||||
| HybridNonlinearFactor::HybridNonlinearFactor( | ||||
|     const KeyVector& continuousKeys, const DiscreteKey& discreteKey, | ||||
|     const std::vector<NonlinearFactorValuePair>& factors) | ||||
|     const std::vector<NonlinearFactorValuePair>& pairs) | ||||
|     : Base(continuousKeys, {discreteKey}) { | ||||
|   std::vector<NonlinearFactorValuePair> pairs; | ||||
|   KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end()); | ||||
|   KeySet factor_keys_set; | ||||
|   for (auto&& [f, val] : factors) { | ||||
|     pairs.emplace_back(f, val); | ||||
|   } | ||||
|   checkKeys(continuousKeys, pairs); | ||||
|   factors_ = FactorValuePairs({discreteKey}, pairs); | ||||
| } | ||||
|  |  | |||
|  | @ -106,11 +106,11 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { | |||
|    * | ||||
|    * @param continuousKeys Vector of keys for continuous factors. | ||||
|    * @param discreteKey The discrete key for the "mode", indexing components. | ||||
|    * @param factors Vector of gaussian factor-scalar pairs, one per mode. | ||||
|    * @param pairs Vector of gaussian factor-scalar pairs, one per mode. | ||||
|    */ | ||||
|   HybridNonlinearFactor(const KeyVector& continuousKeys, | ||||
|                         const DiscreteKey& discreteKey, | ||||
|                         const std::vector<NonlinearFactorValuePair>& factors); | ||||
|                         const std::vector<NonlinearFactorValuePair>& pairs); | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Construct a new HybridNonlinearFactor on a several discrete keys M, | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue