Merge pull request #1135 from borglab/fix/visitWith
						commit
						e5fb4cdaa6
					
				|  | @ -33,6 +33,8 @@ namespace gtsam { | |||
| template <class L> | ||||
| class Assignment : public std::map<L, size_t> { | ||||
|  public: | ||||
|   using std::map<L, size_t>::operator=; | ||||
| 
 | ||||
|   void print(const std::string& s = "Assignment: ") const { | ||||
|     std::cout << s << ": "; | ||||
|     for (const typename Assignment::value_type& keyValue : *this) | ||||
|  |  | |||
|  | @ -666,8 +666,13 @@ namespace gtsam { | |||
|       if (!choice) | ||||
|         throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); | ||||
|       for (size_t i = 0; i < choice->nrChoices(); i++) { | ||||
|         choices[choice->label()] = i;    // Set assignment for label to i
 | ||||
|         choices[choice->label()] = i;  // Set assignment for label to i
 | ||||
| 
 | ||||
|         (*this)(choice->branches()[i]);  // recurse!
 | ||||
| 
 | ||||
|         // Remove the choice so we are backtracking
 | ||||
|         auto choice_it = choices.find(choice->label()); | ||||
|         choices.erase(choice_it); | ||||
|       } | ||||
|     } | ||||
|   }; | ||||
|  |  | |||
|  | @ -346,6 +346,44 @@ TEST(DecisionTree, visitWith) { | |||
|   EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test visit, with Choices argument.
 | ||||
| TEST(DecisionTree, VisitWithPruned) { | ||||
|   // Create pruned tree
 | ||||
|   std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2); | ||||
|   std::vector<std::pair<string, size_t>> labels = {C, B, A}; | ||||
|   std::vector<int> nodes = {0, 0, 2, 3, 4, 4, 6, 7}; | ||||
|   DT tree(labels, nodes); | ||||
| 
 | ||||
|   std::vector<Assignment<string>> choices; | ||||
|   auto func = [&](const Assignment<string>& choice, const int& d) { | ||||
|     choices.push_back(choice); | ||||
|   }; | ||||
|   tree.visitWith(func); | ||||
| 
 | ||||
|   EXPECT_LONGS_EQUAL(6, choices.size()); | ||||
| 
 | ||||
|   Assignment<string> expectedAssignment; | ||||
| 
 | ||||
|   expectedAssignment = {{"B", 0}, {"C", 0}}; | ||||
|   EXPECT(expectedAssignment == choices.at(0)); | ||||
| 
 | ||||
|   expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}}; | ||||
|   EXPECT(expectedAssignment == choices.at(1)); | ||||
| 
 | ||||
|   expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}}; | ||||
|   EXPECT(expectedAssignment == choices.at(2)); | ||||
| 
 | ||||
|   expectedAssignment = {{"B", 0}, {"C", 1}}; | ||||
|   EXPECT(expectedAssignment == choices.at(3)); | ||||
| 
 | ||||
|   expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}}; | ||||
|   EXPECT(expectedAssignment == choices.at(4)); | ||||
| 
 | ||||
|   expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}}; | ||||
|   EXPECT(expectedAssignment == choices.at(5)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| // Test fold.
 | ||||
| TEST(DecisionTree, fold) { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue