Merge pull request #2090 from borglab/fix/decisiontreefactor-restrict

release/4.3a0
Varun Agrawal 2025-04-13 14:34:45 -04:00 committed by GitHub
commit 8a1c4120bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 68 additions and 10 deletions

View File

@ -206,8 +206,8 @@ namespace gtsam {
* (CPU time, number of times, wall time, time + children in seconds, min
* time, max time)
*
* @param addLineBreak Flag indicating if a line break should be added at
* the end. Only used at the top-leve.
* @param addLineBreak Flag indicating if a line break should be
* added at the end. Only used at the top-level.
*/
GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const;
@ -217,8 +217,8 @@ namespace gtsam {
* (CPU time, number of times, wall time, time + children in seconds, min
* time, max time)
*
* @param addLineBreak Flag indicating if a line break should be added at
* the end. Only used at the top-leve.
* @param addLineBreak Flag indicating if a line break should be
* added at the end. Only used at the top-level.
*/
GTSAM_EXPORT void printCsv(bool addLineBreak = false) const;

View File

@ -548,8 +548,20 @@ namespace gtsam {
DiscreteFactor::shared_ptr DecisionTreeFactor::restrict(
const DiscreteValues& assignment) const {
ADT restricted_tree = ADT::restrict(assignment);
return std::make_shared<DecisionTreeFactor>(this->discreteKeys(),
restricted_tree);
// Get all the keys that are not restricted by the assignment
// This ensures that the new restricted factor doesn't have keys
// for which the information has been removed.
DiscreteKeys restricted_keys = this->discreteKeys();
for (auto&& kv : assignment) {
Key key = kv.first;
// Remove the key from the keys list
restricted_keys.erase(
std::remove_if(restricted_keys.begin(), restricted_keys.end(),
[key](const DiscreteKey& k) { return k.first == key; }),
restricted_keys.end());
}
// Create the restricted factor with the appropriate keys and tree.
return std::make_shared<DecisionTreeFactor>(restricted_keys, restricted_tree);
}
/* ************************************************************************ */

View File

@ -172,6 +172,45 @@ TEST(DecisionTreeFactor, enumerate) {
EXPECT(actual == expected);
}
/* ************************************************************************* */
// Test if restricting a factor based on DiscreteValues works.
TEST(DecisionTreeFactor, Restrict) {
// Test for restricting a single value from multiple values.
DiscreteKey A(12, 2), B(5, 3);
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
DiscreteValues fixedValues = {{A.first, 1}};
DecisionTreeFactor restricted_f1 =
*std::static_pointer_cast<DecisionTreeFactor>(f1.restrict(fixedValues));
DecisionTreeFactor expected_f1(B, "4 5 6");
EXPECT(assert_equal(expected_f1, restricted_f1));
// Test for restricting a multiple value from multiple values.
DiscreteKey C(91, 2);
DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12");
fixedValues = {{A.first, 0}, {B.first, 2}};
DecisionTreeFactor restricted_f2 =
*std::static_pointer_cast<DecisionTreeFactor>(f2.restrict(fixedValues));
DecisionTreeFactor expected_f2(C, "5 6");
EXPECT(assert_equal(expected_f2, restricted_f2));
// Edge case of restricting a single value when it is the only value.
DecisionTreeFactor f3(A, "50 100");
fixedValues = {{A.first, 1}}; // select 100
DecisionTreeFactor restricted_f3 =
*std::static_pointer_cast<DecisionTreeFactor>(f3.restrict(fixedValues));
EXPECT_LONGS_EQUAL(0, restricted_f3.discreteKeys().size());
// There should only be 1 value which is 100
EXPECT_LONGS_EQUAL(1, restricted_f3.nrValues());
EXPECT_LONGS_EQUAL(1, restricted_f3.nrLeaves());
EXPECT_DOUBLES_EQUAL(100, restricted_f3.evaluate(DiscreteValues()), 1e-9);
}
namespace pruning_fixture {
DiscreteKey A(1, 2), B(2, 2), C(3, 2);

View File

@ -223,7 +223,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @note If marginal greater than this threshold, the mode gets assigned that
* value and is considered "dead" for hybrid elimination. The mode can then be
* removed since it only has a single possible assignment.
*
* @return A pruned HybridBayesNet
*/
HybridBayesNet prune(size_t maxNrLeaves,

View File

@ -23,7 +23,6 @@
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
namespace gtsam {
/* ************************************************************************* */
@ -237,12 +236,20 @@ HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict(
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
result.push_back(hf->restrict(discreteValues));
} else if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
result.push_back(df->restrict(discreteValues));
auto restricted_df = df->restrict(discreteValues);
// In the case where all the discrete values in the factor
// have been selected, we get a factor without any keys,
// and default values of 0.5.
// Since this factor no longer adds any information, we ignore it to make
// inference faster.
if (restricted_df->discreteKeys().size() > 0) {
result.push_back(restricted_df);
}
} else {
result.push_back(f); // Everything else is just added as is
}
}
return result;
}