diff --git a/cpp/DSF.h b/cpp/DSF.h index 7a05e8e2d..eff4bdd8e 100644 --- a/cpp/DSF.h +++ b/cpp/DSF.h @@ -109,11 +109,13 @@ namespace gtsam { return partitions; } - // get the nodes in the given tree + // get the nodes in the tree with the given label Set set(const Label& label) { Set set; - BOOST_FOREACH(const KeyLabel& pair, (Tree)*this) - if (pair.second==label) set.insert(pair.first); + BOOST_FOREACH(const KeyLabel& pair, (Tree)*this) { + if (pair.second == label || findSet(pair.second) == label) + set.insert(pair.first); + } return set; } diff --git a/cpp/testDSF.cpp b/cpp/testDSF.cpp index 7f4f2ce1f..a0ba0be7c 100644 --- a/cpp/testDSF.cpp +++ b/cpp/testDSF.cpp @@ -195,6 +195,21 @@ TEST(DSF, set) { CHECK(expected == set); } +/* ************************************************************************* */ +TEST(DSF, set2) { + DSFInt dsf; + dsf = dsf.makeSet(5); + dsf = dsf.makeSet(6); + dsf = dsf.makeSet(7); + dsf = dsf.makeUnion(5,6); + dsf = dsf.makeUnion(6,7); + set set = dsf.set(5); + LONGS_EQUAL(3, set.size()); + + std::set expected; expected += 5, 6, 7; + CHECK(expected == set); +} + /* ************************************************************************* */ int func(const int& a) { return a + 10; } TEST(DSF, map) {