diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ebc02c1b5..1ba52642c 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -548,8 +548,16 @@ namespace gtsam { DiscreteFactor::shared_ptr DecisionTreeFactor::restrict( const DiscreteValues& assignment) const { ADT restricted_tree = ADT::restrict(assignment); - return std::make_shared(this->discreteKeys(), - restricted_tree); + DiscreteKeys restricted_keys = this->discreteKeys(); + for (auto&& kv : assignment) { + Key key = kv.first; + // Remove the key from the keys list + restricted_keys.erase( + std::remove_if(restricted_keys.begin(), restricted_keys.end(), + [key](const DiscreteKey& k) { return k.first == key; }), + restricted_keys.end()); + } + return std::make_shared(restricted_keys, restricted_tree); } /* ************************************************************************ */