From 74142e4be19fc59713bd4f180729ad7cc22c689c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 3 Jan 2023 17:37:09 -0500 Subject: [PATCH] GaussianMixture serialization --- gtsam/hybrid/GaussianMixture.h | 10 ++++ gtsam/hybrid/GaussianMixtureFactor.h | 18 +++++++ .../hybrid/tests/testSerializationHybrid.cpp | 48 +++++++++++++++---- 3 files changed, 67 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 240f79dcd..ba84b5ade 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -197,6 +197,16 @@ class GTSAM_EXPORT GaussianMixture */ GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + ar &BOOST_SERIALIZATION_NVP(conditionals_); + } }; /// Return the DiscreteKey vector as a set. diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 9138d6b30..01de2f0f7 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -70,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { bool operator==(const FactorAndConstant &other) const { return factor == other.factor && constant == other.constant; } + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_NVP(factor); + ar &BOOST_SERIALIZATION_NVP(constant); + } }; /// typedef for Decision Tree of Gaussian factors and log-constant. @@ -179,6 +188,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { return sum; } /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(factors_); + } }; // traits diff --git a/gtsam/hybrid/tests/testSerializationHybrid.cpp b/gtsam/hybrid/tests/testSerializationHybrid.cpp index 64f0ce579..5337938dd 100644 --- a/gtsam/hybrid/tests/testSerializationHybrid.cpp +++ b/gtsam/hybrid/tests/testSerializationHybrid.cpp @@ -17,10 +17,12 @@ */ #include +#include #include #include #include #include +#include // Include for test suite #include @@ -29,26 +31,37 @@ using namespace std; using namespace gtsam; using symbol_shorthand::M; using symbol_shorthand::X; +using symbol_shorthand::Z; using namespace serializationTestHelpers; BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor"); BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor"); +BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional"); BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); using ADT = AlgebraicDecisionTree; BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); -BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_DecisionTree_Leaf") -BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_DecisionTree_Choice") +BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_DecisionTree_Leaf"); +BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_DecisionTree_Choice"); -BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixtureFactor, "gtsam_GaussianMixtureFactor") -BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixtureFactor::Factors, "gtsam_GaussianMixtureFactor_Factors") -BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixtureFactor::Factors::Leaf, "gtsam_GaussianMixtureFactor_Factors_Leaf") -BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixtureFactor::Factors::Choice, "gtsam_GaussianMixtureFactor_Factors_Choice") +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor"); +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors, + "gtsam_GaussianMixtureFactor_Factors"); +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Leaf, + "gtsam_GaussianMixtureFactor_Factors_Leaf"); +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Choice, + "gtsam_GaussianMixtureFactor_Factors_Choice"); -// BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixture, "gtsam_GaussianMixture") -// BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixture::Conditionals, -// "gtsam_GaussianMixture_Conditionals") +BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture"); +BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals, + "gtsam_GaussianMixture_Conditionals"); +BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf, + "gtsam_GaussianMixture_Conditionals_Leaf"); +BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice, + "gtsam_GaussianMixture_Conditionals_Choice"); +// Needed since GaussianConditional::FromMeanAndStddev uses it +BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); /* ****************************************************************************/ // Test HybridGaussianFactor serialization. @@ -92,6 +105,23 @@ TEST(HybridSerialization, GaussianMixtureFactor) { EXPECT(equalsBinary(factor)); } +/* ****************************************************************************/ +// Test GaussianMixture serialization. +TEST(HybridSerialization, GaussianMixture) { + const DiscreteKey mode(M(0), 2); + Matrix1 I = Matrix1::Identity(); + const auto conditional0 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5)); + const auto conditional1 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3)); + const GaussianMixture gm({Z(0)}, {X(0)}, {mode}, + {conditional0, conditional1}); + + EXPECT(equalsObj(gm)); + EXPECT(equalsXML(gm)); + EXPECT(equalsBinary(gm)); +} + /* ************************************************************************* */ int main() { TestResult tr;