From 6fcc0870308a1dc0c0517b311250011cff093c01 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 3 Jan 2023 18:15:59 -0500 Subject: [PATCH] serialize DiscreteConditional --- gtsam/discrete/DiscreteConditional.h | 9 ++++++++ .../tests/testDiscreteConditional.cpp | 23 ++++++++++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 6a286633d..b68953eb5 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional /// Internal version of choose DiscreteConditional::ADT choose(const DiscreteValues& given, bool forceComplete) const; + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + } }; // DiscreteConditional diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 9098f7a1d..99ea138b1 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -17,13 +17,14 @@ * @date Feb 14, 2011 */ -#include - #include +#include #include #include #include +#include + using namespace std; using namespace gtsam; @@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) { DiscreteConditional conditional(A | B = "2/2 3/1"); DiscreteConditional prior(B % "1/2"); DiscreteConditional pAB = prior * conditional; - GTSAM_PRINT(pAB); // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8 // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 DiscreteConditional actualA = pAB.marginal(A.first); @@ -368,6 +368,23 @@ TEST(DiscreteConditional, html) { EXPECT(actual == expected); } +/* ************************************************************************* */ +using ADT = AlgebraicDecisionTree; +BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_ADT_Leaf") +BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_ADT_Choice") + +// Check serialization for DiscreteConditional +TEST(DiscreteConditional, Serialization) { + using namespace serializationTestHelpers; + + DiscreteKey A(Symbol('x', 1), 3); + DiscreteConditional conditional(A % "1/2/2"); + + EXPECT(equalsObj(conditional)); + EXPECT(equalsXML(conditional)); + EXPECT(equalsBinary(conditional)); +} + /* ************************************************************************* */ int main() { TestResult tr;