From 3779f5280d14f2700227e403d87f8189d018b5ad Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 18 Sep 2023 22:57:59 -0400 Subject: [PATCH] use STL to make DecisionTreeFactor::combine more efficient --- gtsam/discrete/DecisionTreeFactor.cpp | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 17731453a..7202e6853 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -164,15 +164,23 @@ namespace gtsam { } // create new factor, note we collect keys that are not in frontalKeys - // TODO(frank): why do we need this??? result should contain correct keys!!! + /* + Due to branch merging, the labels in `result` may be missing some keys + E.g. After branch merging, we may get a ADT like: + Leaf [2] 1.0204082 + + This is missing the key values used for branching. + */ + KeyVector difference, frontalKeys_(frontalKeys), keys_(keys()); + // Get the difference of the frontalKeys and the factor keys using set_difference + std::sort(keys_.begin(), keys_.end()); + std::sort(frontalKeys_.begin(), frontalKeys_.end()); + std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(), + frontalKeys_.end(), back_inserter(difference)); + DiscreteKeys dkeys; - for (i = 0; i < keys().size(); i++) { - Key j = keys()[i]; - // TODO(frank): inefficient! - if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != - frontalKeys.end()) - continue; - dkeys.push_back(DiscreteKey(j, cardinality(j))); + for (Key key : difference) { + dkeys.push_back(DiscreteKey(key, cardinality(key))); } return std::make_shared(dkeys, result); }