From 2653c2f8fb075aa0e727bb6eac347a77eed4e096 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 3 Jan 2023 16:34:04 -0500 Subject: [PATCH] serialization support for DecisionTreeFactor --- gtsam/discrete/DecisionTreeFactor.h | 10 ++++++++ .../discrete/tests/testDecisionTreeFactor.cpp | 25 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index bb9ddbd96..f1df7ae03 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -231,6 +231,16 @@ namespace gtsam { const Names& names = {}) const override; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT); + ar& BOOST_SERIALIZATION_NVP(cardinalities_); + } }; // traits diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 1829db034..d4ca9faa6 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -220,6 +221,30 @@ TEST(DecisionTreeFactor, htmlWithValueFormatter) { EXPECT(actual == expected); } +/* ************************************************************************* */ +BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); +using ADT = AlgebraicDecisionTree; +BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); +BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_DecisionTree_Leaf") +BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_DecisionTree_Choice") + +// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor +TEST(DecisionTreeFactor, Serialization) { + using namespace serializationTestHelpers; + + DiscreteKey A(1, 2), B(2, 2), C(3, 2); + + DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8"); + EXPECT(equalsObj(tree)); + EXPECT(equalsXML(tree)); + EXPECT(equalsBinary(tree)); + + DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + EXPECT(equalsObj(f)); + EXPECT(equalsXML(f)); + EXPECT(equalsBinary(f)); +} + /* ************************************************************************* */ int main() { TestResult tr;