Merge pull request #2090 from borglab/fix/decisiontreefactor-restrict
commit
8a1c4120bb
|
@ -206,8 +206,8 @@ namespace gtsam {
|
|||
* (CPU time, number of times, wall time, time + children in seconds, min
|
||||
* time, max time)
|
||||
*
|
||||
* @param addLineBreak Flag indicating if a line break should be added at
|
||||
* the end. Only used at the top-leve.
|
||||
* @param addLineBreak Flag indicating if a line break should be
|
||||
* added at the end. Only used at the top-level.
|
||||
*/
|
||||
GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const;
|
||||
|
||||
|
@ -217,8 +217,8 @@ namespace gtsam {
|
|||
* (CPU time, number of times, wall time, time + children in seconds, min
|
||||
* time, max time)
|
||||
*
|
||||
* @param addLineBreak Flag indicating if a line break should be added at
|
||||
* the end. Only used at the top-leve.
|
||||
* @param addLineBreak Flag indicating if a line break should be
|
||||
* added at the end. Only used at the top-level.
|
||||
*/
|
||||
GTSAM_EXPORT void printCsv(bool addLineBreak = false) const;
|
||||
|
||||
|
|
|
@ -548,8 +548,20 @@ namespace gtsam {
|
|||
DiscreteFactor::shared_ptr DecisionTreeFactor::restrict(
|
||||
const DiscreteValues& assignment) const {
|
||||
ADT restricted_tree = ADT::restrict(assignment);
|
||||
return std::make_shared<DecisionTreeFactor>(this->discreteKeys(),
|
||||
restricted_tree);
|
||||
// Get all the keys that are not restricted by the assignment
|
||||
// This ensures that the new restricted factor doesn't have keys
|
||||
// for which the information has been removed.
|
||||
DiscreteKeys restricted_keys = this->discreteKeys();
|
||||
for (auto&& kv : assignment) {
|
||||
Key key = kv.first;
|
||||
// Remove the key from the keys list
|
||||
restricted_keys.erase(
|
||||
std::remove_if(restricted_keys.begin(), restricted_keys.end(),
|
||||
[key](const DiscreteKey& k) { return k.first == key; }),
|
||||
restricted_keys.end());
|
||||
}
|
||||
// Create the restricted factor with the appropriate keys and tree.
|
||||
return std::make_shared<DecisionTreeFactor>(restricted_keys, restricted_tree);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
@ -172,6 +172,45 @@ TEST(DecisionTreeFactor, enumerate) {
|
|||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test if restricting a factor based on DiscreteValues works.
|
||||
TEST(DecisionTreeFactor, Restrict) {
|
||||
// Test for restricting a single value from multiple values.
|
||||
DiscreteKey A(12, 2), B(5, 3);
|
||||
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
|
||||
DiscreteValues fixedValues = {{A.first, 1}};
|
||||
|
||||
DecisionTreeFactor restricted_f1 =
|
||||
*std::static_pointer_cast<DecisionTreeFactor>(f1.restrict(fixedValues));
|
||||
|
||||
DecisionTreeFactor expected_f1(B, "4 5 6");
|
||||
EXPECT(assert_equal(expected_f1, restricted_f1));
|
||||
|
||||
// Test for restricting a multiple value from multiple values.
|
||||
DiscreteKey C(91, 2);
|
||||
DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12");
|
||||
fixedValues = {{A.first, 0}, {B.first, 2}};
|
||||
|
||||
DecisionTreeFactor restricted_f2 =
|
||||
*std::static_pointer_cast<DecisionTreeFactor>(f2.restrict(fixedValues));
|
||||
|
||||
DecisionTreeFactor expected_f2(C, "5 6");
|
||||
EXPECT(assert_equal(expected_f2, restricted_f2));
|
||||
|
||||
// Edge case of restricting a single value when it is the only value.
|
||||
DecisionTreeFactor f3(A, "50 100");
|
||||
fixedValues = {{A.first, 1}}; // select 100
|
||||
|
||||
DecisionTreeFactor restricted_f3 =
|
||||
*std::static_pointer_cast<DecisionTreeFactor>(f3.restrict(fixedValues));
|
||||
|
||||
EXPECT_LONGS_EQUAL(0, restricted_f3.discreteKeys().size());
|
||||
// There should only be 1 value which is 100
|
||||
EXPECT_LONGS_EQUAL(1, restricted_f3.nrValues());
|
||||
EXPECT_LONGS_EQUAL(1, restricted_f3.nrLeaves());
|
||||
EXPECT_DOUBLES_EQUAL(100, restricted_f3.evaluate(DiscreteValues()), 1e-9);
|
||||
}
|
||||
|
||||
namespace pruning_fixture {
|
||||
|
||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||
|
|
|
@ -223,7 +223,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @note If marginal greater than this threshold, the mode gets assigned that
|
||||
* value and is considered "dead" for hybrid elimination. The mode can then be
|
||||
* removed since it only has a single possible assignment.
|
||||
|
||||
*
|
||||
* @return A pruned HybridBayesNet
|
||||
*/
|
||||
HybridBayesNet prune(size_t maxNrLeaves,
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
#include <gtsam/hybrid/HybridNonlinearFactor.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -237,7 +236,15 @@ HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict(
|
|||
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
||||
result.push_back(hf->restrict(discreteValues));
|
||||
} else if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
result.push_back(df->restrict(discreteValues));
|
||||
auto restricted_df = df->restrict(discreteValues);
|
||||
// In the case where all the discrete values in the factor
|
||||
// have been selected, we get a factor without any keys,
|
||||
// and default values of 0.5.
|
||||
// Since this factor no longer adds any information, we ignore it to make
|
||||
// inference faster.
|
||||
if (restricted_df->discreteKeys().size() > 0) {
|
||||
result.push_back(restricted_df);
|
||||
}
|
||||
} else {
|
||||
result.push_back(f); // Everything else is just added as is
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue