diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 8be314c4e..e471cb02f 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -112,7 +112,7 @@ std::function &, double)> prunerFunc( DiscreteValues::CartesianProduct(set_diff); for (const DiscreteValues &assignment : assignments) { DiscreteValues augmented_values(values); - augmented_values.insert(assignment.begin(), assignment.end()); + augmented_values.insert(assignment); // If any one of the sub-branches are non-zero, // we need this probability. diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h index ff896041e..efe65bc31 100644 --- a/gtsam/hybrid/HybridValues.h +++ b/gtsam/hybrid/HybridValues.h @@ -104,6 +104,28 @@ class GTSAM_EXPORT HybridValues { * @param j The index with which the value will be associated. */ void insert(Key j, const Vector& value) { continuous_.insert(j, value); } + /** Insert all continuous values from \c values. Throws an invalid_argument + * exception if any keys to be inserted are already used. */ + HybridValues& insert(const VectorValues& values) { + continuous_.insert(values); + return *this; + } + + /** Insert all discrete values from \c values. Throws an invalid_argument + * exception if any keys to be inserted are already used. */ + HybridValues& insert(const DiscreteValues& values) { + discrete_.insert(values); + return *this; + } + + /** Insert all values from \c values. Throws an invalid_argument exception if + * any keys to be inserted are already used. */ + HybridValues& insert(const HybridValues& values) { + continuous_.insert(values.continuous()); + discrete_.insert(values.discrete()); + return *this; + } + // TODO(Shangjie)- insert_or_assign() , similar to Values.h /** @@ -118,10 +140,33 @@ class GTSAM_EXPORT HybridValues { */ Vector& at(Key j) { return continuous_.at(j); }; - /** For all key/value pairs in \c values, replace values with corresponding keys in this class - * with those in \c values. Throws std::out_of_range if any keys in \c values are not present - * in this class. */ - void update(const VectorValues& values) { continuous_.update(values); } + /** For all key/value pairs in \c values, replace continuous values with + * corresponding keys in this object with those in \c values. Throws + * std::out_of_range if any keys in \c values are not present in this object. + */ + HybridValues& update(const VectorValues& values) { + continuous_.update(values); + return *this; + } + + /** For all key/value pairs in \c values, replace discrete values with + * corresponding keys in this object with those in \c values. Throws + * std::out_of_range if any keys in \c values are not present in this object. + */ + HybridValues& update(const DiscreteValues& values) { + discrete_.update(values); + return *this; + } + + /** For all key/value pairs in \c values, replace all values with + * corresponding keys in this object with those in \c values. Throws + * std::out_of_range if any keys in \c values are not present in this object. + */ + HybridValues& update(const HybridValues& values) { + continuous_.update(values.continuous()); + discrete_.update(values.discrete()); + return *this; + } /// @} /// @name Wrapper support diff --git a/gtsam/hybrid/tests/testHybridValues.cpp b/gtsam/hybrid/tests/testHybridValues.cpp index 6f510601d..02e1cb733 100644 --- a/gtsam/hybrid/tests/testHybridValues.cpp +++ b/gtsam/hybrid/tests/testHybridValues.cpp @@ -32,22 +32,45 @@ using namespace std; using namespace gtsam; -TEST(HybridValues, basics) { +static const HybridValues kExample{{{99, Vector2(2, 3)}}, {{100, 3}}}; + +/* ************************************************************************* */ +TEST(HybridValues, Basics) { HybridValues values; values.insert(99, Vector2(2, 3)); values.insert(100, 3); + EXPECT(assert_equal(kExample, values)); EXPECT(assert_equal(values.at(99), Vector2(2, 3))); EXPECT(assert_equal(values.atDiscrete(100), int(3))); - values.print(); - HybridValues values2; values2.insert(100, 3); values2.insert(99, Vector2(2, 3)); - EXPECT(assert_equal(values2, values)); + EXPECT(assert_equal(kExample, values2)); +} - values2.insert(98, Vector2(2, 3)); - EXPECT(!assert_equal(values2, values)); +/* ************************************************************************* */ +// Check insert +TEST(HybridValues, Insert) { + HybridValues actual; + EXPECT(assert_equal({{}, {{100, 3}}}, // + actual.insert(DiscreteValues{{100, 3}}))); + EXPECT(assert_equal(kExample, // + actual.insert(VectorValues{{99, Vector2(2, 3)}}))); + HybridValues actual2; + EXPECT(assert_equal(kExample, actual2.insert(kExample))); +} + +/* ************************************************************************* */ +// Check update. +TEST(HybridValues, Update) { + HybridValues actual(kExample); + EXPECT(assert_equal({{{99, Vector2(2, 3)}}, {{100, 2}}}, + actual.update(DiscreteValues{{100, 2}}))); + EXPECT(assert_equal({{{99, Vector1(4)}}, {{100, 2}}}, + actual.update(VectorValues{{99, Vector1(4)}}))); + HybridValues actual2(kExample); + EXPECT(assert_equal(kExample, actual2.update(kExample))); } /* ************************************************************************* */