diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1ba52642c..39d678ad2 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -548,6 +548,9 @@ namespace gtsam { DiscreteFactor::shared_ptr DecisionTreeFactor::restrict( const DiscreteValues& assignment) const { ADT restricted_tree = ADT::restrict(assignment); + // Get all the keys that are not restricted by the assignment + // This ensures that the new restricted factor doesn't have keys + // for which the information has been removed. DiscreteKeys restricted_keys = this->discreteKeys(); for (auto&& kv : assignment) { Key key = kv.first; @@ -557,6 +560,7 @@ DiscreteFactor::shared_ptr DecisionTreeFactor::restrict( [key](const DiscreteKey& k) { return k.first == key; }), restricted_keys.end()); } + // Create the restricted factor with the appropriate keys and tree. return std::make_shared(restricted_keys, restricted_tree); } diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 6ab28ccb4..11dd6a2a5 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -237,8 +237,11 @@ HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict( result.push_back(hf->restrict(discreteValues)); } else if (auto df = dynamic_pointer_cast(f)) { auto restricted_df = df->restrict(discreteValues); - // In the case where all the discrete values have been selected, - // we ignore the factor since it doesn't add any information + // In the case where all the discrete values in the factor + // have been selected, we get a factor without any keys, + // and default values of 0.5. + // Since this factor no longer adds any information, we ignore it to make + // inference faster. if (restricted_df->discreteKeys().size() > 0) { result.push_back(restricted_df); }