Merge pull request #2090 from borglab/fix/decisiontreefactor-restrict
commit
8a1c4120bb
|
@ -206,8 +206,8 @@ namespace gtsam {
|
||||||
* (CPU time, number of times, wall time, time + children in seconds, min
|
* (CPU time, number of times, wall time, time + children in seconds, min
|
||||||
* time, max time)
|
* time, max time)
|
||||||
*
|
*
|
||||||
* @param addLineBreak Flag indicating if a line break should be added at
|
* @param addLineBreak Flag indicating if a line break should be
|
||||||
* the end. Only used at the top-leve.
|
* added at the end. Only used at the top-level.
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const;
|
GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const;
|
||||||
|
|
||||||
|
@ -217,8 +217,8 @@ namespace gtsam {
|
||||||
* (CPU time, number of times, wall time, time + children in seconds, min
|
* (CPU time, number of times, wall time, time + children in seconds, min
|
||||||
* time, max time)
|
* time, max time)
|
||||||
*
|
*
|
||||||
* @param addLineBreak Flag indicating if a line break should be added at
|
* @param addLineBreak Flag indicating if a line break should be
|
||||||
* the end. Only used at the top-leve.
|
* added at the end. Only used at the top-level.
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT void printCsv(bool addLineBreak = false) const;
|
GTSAM_EXPORT void printCsv(bool addLineBreak = false) const;
|
||||||
|
|
||||||
|
|
|
@ -548,8 +548,20 @@ namespace gtsam {
|
||||||
DiscreteFactor::shared_ptr DecisionTreeFactor::restrict(
|
DiscreteFactor::shared_ptr DecisionTreeFactor::restrict(
|
||||||
const DiscreteValues& assignment) const {
|
const DiscreteValues& assignment) const {
|
||||||
ADT restricted_tree = ADT::restrict(assignment);
|
ADT restricted_tree = ADT::restrict(assignment);
|
||||||
return std::make_shared<DecisionTreeFactor>(this->discreteKeys(),
|
// Get all the keys that are not restricted by the assignment
|
||||||
restricted_tree);
|
// This ensures that the new restricted factor doesn't have keys
|
||||||
|
// for which the information has been removed.
|
||||||
|
DiscreteKeys restricted_keys = this->discreteKeys();
|
||||||
|
for (auto&& kv : assignment) {
|
||||||
|
Key key = kv.first;
|
||||||
|
// Remove the key from the keys list
|
||||||
|
restricted_keys.erase(
|
||||||
|
std::remove_if(restricted_keys.begin(), restricted_keys.end(),
|
||||||
|
[key](const DiscreteKey& k) { return k.first == key; }),
|
||||||
|
restricted_keys.end());
|
||||||
|
}
|
||||||
|
// Create the restricted factor with the appropriate keys and tree.
|
||||||
|
return std::make_shared<DecisionTreeFactor>(restricted_keys, restricted_tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
@ -172,6 +172,45 @@ TEST(DecisionTreeFactor, enumerate) {
|
||||||
EXPECT(actual == expected);
|
EXPECT(actual == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test if restricting a factor based on DiscreteValues works.
|
||||||
|
TEST(DecisionTreeFactor, Restrict) {
|
||||||
|
// Test for restricting a single value from multiple values.
|
||||||
|
DiscreteKey A(12, 2), B(5, 3);
|
||||||
|
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
|
||||||
|
DiscreteValues fixedValues = {{A.first, 1}};
|
||||||
|
|
||||||
|
DecisionTreeFactor restricted_f1 =
|
||||||
|
*std::static_pointer_cast<DecisionTreeFactor>(f1.restrict(fixedValues));
|
||||||
|
|
||||||
|
DecisionTreeFactor expected_f1(B, "4 5 6");
|
||||||
|
EXPECT(assert_equal(expected_f1, restricted_f1));
|
||||||
|
|
||||||
|
// Test for restricting a multiple value from multiple values.
|
||||||
|
DiscreteKey C(91, 2);
|
||||||
|
DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12");
|
||||||
|
fixedValues = {{A.first, 0}, {B.first, 2}};
|
||||||
|
|
||||||
|
DecisionTreeFactor restricted_f2 =
|
||||||
|
*std::static_pointer_cast<DecisionTreeFactor>(f2.restrict(fixedValues));
|
||||||
|
|
||||||
|
DecisionTreeFactor expected_f2(C, "5 6");
|
||||||
|
EXPECT(assert_equal(expected_f2, restricted_f2));
|
||||||
|
|
||||||
|
// Edge case of restricting a single value when it is the only value.
|
||||||
|
DecisionTreeFactor f3(A, "50 100");
|
||||||
|
fixedValues = {{A.first, 1}}; // select 100
|
||||||
|
|
||||||
|
DecisionTreeFactor restricted_f3 =
|
||||||
|
*std::static_pointer_cast<DecisionTreeFactor>(f3.restrict(fixedValues));
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(0, restricted_f3.discreteKeys().size());
|
||||||
|
// There should only be 1 value which is 100
|
||||||
|
EXPECT_LONGS_EQUAL(1, restricted_f3.nrValues());
|
||||||
|
EXPECT_LONGS_EQUAL(1, restricted_f3.nrLeaves());
|
||||||
|
EXPECT_DOUBLES_EQUAL(100, restricted_f3.evaluate(DiscreteValues()), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
namespace pruning_fixture {
|
namespace pruning_fixture {
|
||||||
|
|
||||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||||
|
|
|
@ -223,7 +223,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @note If marginal greater than this threshold, the mode gets assigned that
|
* @note If marginal greater than this threshold, the mode gets assigned that
|
||||||
* value and is considered "dead" for hybrid elimination. The mode can then be
|
* value and is considered "dead" for hybrid elimination. The mode can then be
|
||||||
* removed since it only has a single possible assignment.
|
* removed since it only has a single possible assignment.
|
||||||
|
*
|
||||||
* @return A pruned HybridBayesNet
|
* @return A pruned HybridBayesNet
|
||||||
*/
|
*/
|
||||||
HybridBayesNet prune(size_t maxNrLeaves,
|
HybridBayesNet prune(size_t maxNrLeaves,
|
||||||
|
|
|
@ -23,7 +23,6 @@
|
||||||
#include <gtsam/hybrid/HybridNonlinearFactor.h>
|
#include <gtsam/hybrid/HybridNonlinearFactor.h>
|
||||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -237,12 +236,20 @@ HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict(
|
||||||
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
||||||
result.push_back(hf->restrict(discreteValues));
|
result.push_back(hf->restrict(discreteValues));
|
||||||
} else if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
} else if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
result.push_back(df->restrict(discreteValues));
|
auto restricted_df = df->restrict(discreteValues);
|
||||||
|
// In the case where all the discrete values in the factor
|
||||||
|
// have been selected, we get a factor without any keys,
|
||||||
|
// and default values of 0.5.
|
||||||
|
// Since this factor no longer adds any information, we ignore it to make
|
||||||
|
// inference faster.
|
||||||
|
if (restricted_df->discreteKeys().size() > 0) {
|
||||||
|
result.push_back(restricted_df);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
result.push_back(f); // Everything else is just added as is
|
result.push_back(f); // Everything else is just added as is
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue