From bfb865c12e5ab9855c70d1b0e55dca9a0ff0f68f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 1 Sep 2022 00:00:06 -0400 Subject: [PATCH 1/7] DiscreteKeys serialization --- gtsam/discrete/DiscreteKey.h | 31 ++++++++++++++++++++- gtsam/discrete/tests/testDiscreteFactor.cpp | 20 +++++++++++-- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index 297d5570d..ec76f5941 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -79,8 +80,36 @@ namespace gtsam { } } + bool equals(const DiscreteKeys& other, double tol = 0) const { + if (this->size() != other.size()) { + return false; + } + + for (size_t i = 0; i < this->size(); i++) { + if (this->at(i).first != other.at(i).first || + this->at(i).second != other.at(i).second) { + return false; + } + } + return true; + } + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& boost::serialization::make_nvp( + "DiscreteKeys", + boost::serialization::base_object>(*this)); + } + }; // DiscreteKeys /// Create a list from two keys GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); -} + + // traits + template <> + struct traits : public Testable {}; + + } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteFactor.cpp b/gtsam/discrete/tests/testDiscreteFactor.cpp index 8681cf7eb..db0491c9d 100644 --- a/gtsam/discrete/tests/testDiscreteFactor.cpp +++ b/gtsam/discrete/tests/testDiscreteFactor.cpp @@ -16,14 +16,29 @@ * @author Duy-Nguyen Ta */ -#include -#include #include +#include +#include +#include + #include using namespace boost::assign; using namespace std; using namespace gtsam; +using namespace gtsam::serializationTestHelpers; + +/* ************************************************************************* */ +TEST(DisreteKeys, Serialization) { + DiscreteKeys keys; + keys& DiscreteKey(0, 2); + keys& DiscreteKey(1, 3); + keys& DiscreteKey(2, 4); + + EXPECT(equalsObj(keys)); + EXPECT(equalsXML(keys)); + EXPECT(equalsBinary(keys)); +} /* ************************************************************************* */ int main() { @@ -31,4 +46,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - From eb5092897b442058d0e37e3eb506bf470807636a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 1 Sep 2022 00:03:31 -0400 Subject: [PATCH 2/7] add serialization for HybridFactor and HybridConditional --- gtsam/hybrid/HybridConditional.h | 9 +++++++++ gtsam/hybrid/HybridFactor.h | 14 ++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 96ea6d969..b43bb9945 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -178,6 +178,15 @@ class GTSAM_EXPORT HybridConditional /// Get the type-erased pointer to the inner type boost::shared_ptr inner() { return inner_; } + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + } + }; // HybridConditional // traits diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 13dc2e6e6..b3cdc231b 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -47,6 +47,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { bool isContinuous_ = false; bool isHybrid_ = false; + // TODO(Varun) remove size_t nrContinuous_ = 0; protected: @@ -129,6 +130,19 @@ class GTSAM_EXPORT HybridFactor : public Factor { const KeyVector &continuousKeys() const { return continuousKeys_; } /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(isDiscrete_); + ar &BOOST_SERIALIZATION_NVP(isContinuous_); + ar &BOOST_SERIALIZATION_NVP(isHybrid_); + ar &BOOST_SERIALIZATION_NVP(discreteKeys_); + ar &BOOST_SERIALIZATION_NVP(continuousKeys_); + } }; // HybridFactor From 8692ae63eaa5eafc4465c55f7392e9cf02cdc692 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 1 Sep 2022 00:03:55 -0400 Subject: [PATCH 3/7] Make HybridBayesNet testable and add serialization --- gtsam/hybrid/HybridBayesNet.h | 54 ++++++++++++++++++++--- gtsam/hybrid/tests/testHybridBayesNet.cpp | 15 +++++++ 2 files changed, 63 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 616ea0698..e84103a50 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include #include @@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using shared_ptr = boost::shared_ptr; using sharedConditional = boost::shared_ptr; + /// @name Standard Constructors + /// @{ + /** Construct empty bayes net */ HybridBayesNet() = default; - /// Prune the Hybrid Bayes Net given the discrete decision tree. - HybridBayesNet prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const; + /// @} + /// @name Testable + /// @{ + + /** Check equality */ + bool equals(const This &bn, double tol = 1e-9) const { + return Base::equals(bn, tol); + } + + /// print graph + void print( + const std::string &s = "", + const KeyFormatter &formatter = DefaultKeyFormatter) const override { + Base::print(s, formatter); + } + + /// @} + /// @name Standard Interface + /// @{ /// Add HybridConditional to Bayes Net using Base::add; @@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ GaussianBayesNet choose(const DiscreteValues &assignment) const; - /// Solve the HybridBayesNet by back-substitution. - /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and - /// put this method there? + /** + * @brief Solve the HybridBayesNet by first computing the MPE of all the + * discrete variables and then optimizing the continuous variables based on + * the MPE assignment. + * + * @return HybridValues + */ HybridValues optimize() const; /** @@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @return Values */ VectorValues optimize(const DiscreteValues &assignment) const; + + /// Prune the Hybrid Bayes Net given the discrete decision tree. + HybridBayesNet prune( + const DecisionTreeFactor::shared_ptr &discreteFactor) const; + + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } }; +/// traits +template <> +struct traits : public Testable {}; + } // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index c7516c0f6..bf9385bc4 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -18,6 +18,7 @@ * @date December 2021 */ +#include #include #include @@ -28,6 +29,8 @@ using namespace std; using namespace gtsam; +using namespace gtsam::serializationTestHelpers; + using noiseModel::Isotropic; using symbol_shorthand::M; using symbol_shorthand::X; @@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } +/* ****************************************************************************/ +// Test HybridBayesNet serialization. +TEST(HybridBayesNet, Serialization) { + Switching s(4); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); + + EXPECT(equalsObj(hbn)); + EXPECT(equalsXML(hbn)); + EXPECT(equalsBinary(hbn)); +} + /* ************************************************************************* */ int main() { TestResult tr; From b16b05ea2caef84a69bb5cda82f2472f350811d4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 1 Sep 2022 00:04:19 -0400 Subject: [PATCH 4/7] Make HybridBayesTree testable and add serialization --- gtsam/hybrid/HybridBayesTree.h | 12 ++++++++++++ gtsam/hybrid/tests/testHybridBayesTree.cpp | 15 +++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 361fbe86f..3fa344d4d 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -89,8 +89,20 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { VectorValues optimize(const DiscreteValues& assignment) const; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } }; +/// traits +template <> +struct traits : public Testable {}; + /** * @brief Class for Hybrid Bayes tree orphan subtrees. * diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index d457e6b74..0908b8cb5 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -16,6 +16,7 @@ * @date August 2022 */ +#include #include #include #include @@ -143,6 +144,20 @@ TEST(HybridBayesTree, Optimize) { EXPECT(assert_equal(expectedValues, delta.continuous())); } +/* ****************************************************************************/ +// Test HybridBayesTree serialization. +TEST(HybridBayesTree, Serialization) { + Switching s(4); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesTree hbt = + *(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); + + using namespace gtsam::serializationTestHelpers; + EXPECT(equalsObj(hbt)); + EXPECT(equalsXML(hbt)); + EXPECT(equalsBinary(hbt)); +} + /* ************************************************************************* */ int main() { TestResult tr; From c6ebbdc70812715d6f578c6f5fa5d6eb72b33cc4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 1 Sep 2022 00:06:29 -0400 Subject: [PATCH 5/7] add serialization test for GaussianBayesNet --- gtsam/linear/tests/testGaussianBayesNet.cpp | 31 +++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/gtsam/linear/tests/testGaussianBayesNet.cpp b/gtsam/linear/tests/testGaussianBayesNet.cpp index 2b125265f..fc9e52b5c 100644 --- a/gtsam/linear/tests/testGaussianBayesNet.cpp +++ b/gtsam/linear/tests/testGaussianBayesNet.cpp @@ -350,6 +350,37 @@ TEST(GaussianBayesNet, Dot) { "}"); } +#include +using namespace gtsam::serializationTestHelpers; + +/* ****************************************************************************/ +// Test GaussianBayesNet serialization. +TEST(GaussianBayesNet, Serialization) { + // Create an arbitrary Bayes Net + GaussianBayesNet gbn; + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3, + (Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4, + (Matrix2() << 11.0, 12.0, 13.0, 14.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(), + 2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4, + (Matrix2() << 25.0, 26.0, 27.0, 28.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(), + 3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(), + 4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished())); + + EXPECT(equalsObj(gbn)); + EXPECT(equalsXML(gbn)); + EXPECT(equalsBinary(gbn)); +} + + /* ************************************************************************* */ int main() { TestResult tr; From ab017dfd19e8027eedb994d313b890d5118bf0f1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 1 Sep 2022 10:40:48 -0400 Subject: [PATCH 6/7] move DiscreteKeys code to .cpp --- gtsam/discrete/DiscreteKey.cpp | 21 +++++++++++++++++++++ gtsam/discrete/DiscreteKey.h | 22 +++------------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/gtsam/discrete/DiscreteKey.cpp b/gtsam/discrete/DiscreteKey.cpp index 121d61103..06ed2ca3b 100644 --- a/gtsam/discrete/DiscreteKey.cpp +++ b/gtsam/discrete/DiscreteKey.cpp @@ -48,4 +48,25 @@ namespace gtsam { return keys & key2; } + void DiscreteKeys::print(const std::string& s, + const KeyFormatter& keyFormatter) const { + for (auto&& dkey : *this) { + std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second + << std::endl; + } + } + + bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const { + if (this->size() != other.size()) { + return false; + } + + for (size_t i = 0; i < this->size(); i++) { + if (this->at(i).first != other.at(i).first || + this->at(i).second != other.at(i).second) { + return false; + } + } + return true; + } } diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index ec76f5941..8e0802d83 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -73,26 +73,10 @@ namespace gtsam { /// Print the keys and cardinalities. void print(const std::string& s = "", - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { - for (auto&& dkey : *this) { - std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second - << std::endl; - } - } + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - bool equals(const DiscreteKeys& other, double tol = 0) const { - if (this->size() != other.size()) { - return false; - } - - for (size_t i = 0; i < this->size(); i++) { - if (this->at(i).first != other.at(i).first || - this->at(i).second != other.at(i).second) { - return false; - } - } - return true; - } + /// Check equality to another DiscreteKeys object. + bool equals(const DiscreteKeys& other, double tol = 0) const; /** Serialization function */ friend class boost::serialization::access; From 27a9d566028c694bfd7e8297a999d9c6995e76c8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 1 Sep 2022 13:47:18 -0400 Subject: [PATCH 7/7] move GaussianBayesNet serialization test to testSerializationLinear --- gtsam/linear/tests/testGaussianBayesNet.cpp | 31 ------------------- .../linear/tests/testSerializationLinear.cpp | 27 ++++++++++++++++ 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/gtsam/linear/tests/testGaussianBayesNet.cpp b/gtsam/linear/tests/testGaussianBayesNet.cpp index fc9e52b5c..2b125265f 100644 --- a/gtsam/linear/tests/testGaussianBayesNet.cpp +++ b/gtsam/linear/tests/testGaussianBayesNet.cpp @@ -350,37 +350,6 @@ TEST(GaussianBayesNet, Dot) { "}"); } -#include -using namespace gtsam::serializationTestHelpers; - -/* ****************************************************************************/ -// Test GaussianBayesNet serialization. -TEST(GaussianBayesNet, Serialization) { - // Create an arbitrary Bayes Net - GaussianBayesNet gbn; - gbn += GaussianConditional::shared_ptr(new GaussianConditional( - 0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3, - (Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4, - (Matrix2() << 11.0, 12.0, 13.0, 14.0).finished())); - gbn += GaussianConditional::shared_ptr(new GaussianConditional( - 1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(), - 2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4, - (Matrix2() << 25.0, 26.0, 27.0, 28.0).finished())); - gbn += GaussianConditional::shared_ptr(new GaussianConditional( - 2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(), - 3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished())); - gbn += GaussianConditional::shared_ptr(new GaussianConditional( - 3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(), - 4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished())); - gbn += GaussianConditional::shared_ptr(new GaussianConditional( - 4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished())); - - EXPECT(equalsObj(gbn)); - EXPECT(equalsXML(gbn)); - EXPECT(equalsBinary(gbn)); -} - - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/linear/tests/testSerializationLinear.cpp b/gtsam/linear/tests/testSerializationLinear.cpp index 881b2830e..ee21de364 100644 --- a/gtsam/linear/tests/testSerializationLinear.cpp +++ b/gtsam/linear/tests/testSerializationLinear.cpp @@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) { EXPECT(equalsBinary(graph)); } +/* ****************************************************************************/ +TEST(Serialization, gaussian_bayes_net) { + // Create an arbitrary Bayes Net + GaussianBayesNet gbn; + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3, + (Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4, + (Matrix2() << 11.0, 12.0, 13.0, 14.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(), + 2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4, + (Matrix2() << 25.0, 26.0, 27.0, 28.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(), + 3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(), + 4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished())); + + std::string serialized = serialize(gbn); + GaussianBayesNet actual; + deserialize(serialized, actual); + EXPECT(assert_equal(gbn, actual)); +} + /* ************************************************************************* */ TEST (Serialization, gaussian_bayes_tree) { const Key x1=1, x2=2, x3=3, x4=4;