diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index f72743206..d01c91840 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -64,6 +64,9 @@ namespace gtsam { */ size_t nrAssignments_; + /// Default constructor for serialization. + Leaf() {} + /// Constructor from constant Leaf(const Y& constant, size_t nrAssignments = 1) : constant_(constant), nrAssignments_(nrAssignments) {} @@ -154,6 +157,18 @@ namespace gtsam { } bool isLeaf() const override { return true; } + + private: + using Base = DecisionTree::Node; + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_NVP(constant_); + ar& BOOST_SERIALIZATION_NVP(nrAssignments_); + } }; // Leaf /****************************************************************************/ @@ -177,6 +192,9 @@ namespace gtsam { using ChoicePtr = boost::shared_ptr; public: + /// Default constructor for serialization. + Choice() {} + ~Choice() override { #ifdef DT_DEBUG_MEMORY std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() @@ -428,6 +446,19 @@ namespace gtsam { r->push_back(branch->choose(label, index)); return Unique(r); } + + private: + using Base = DecisionTree::Node; + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_NVP(label_); + ar& BOOST_SERIALIZATION_NVP(branches_); + ar& BOOST_SERIALIZATION_NVP(allSame_); + } }; // Choice /****************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 01a57637f..a8764a98f 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -19,9 +19,11 @@ #pragma once +#include #include #include +#include #include #include #include @@ -113,6 +115,12 @@ namespace gtsam { virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; virtual Ptr choose(const L& label, size_t index) const = 0; virtual bool isLeaf() const = 0; + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) {} }; /** ------------------------ Node base class --------------------------- */ @@ -364,8 +372,19 @@ namespace gtsam { compose(Iterator begin, Iterator end, const L& label) const; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_NVP(root_); + } }; // DecisionTree + template + struct traits> : public Testable> {}; + /** free versions of apply */ /// Apply unary operator `op` to DecisionTree `f`. diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index bb9ddbd96..f1df7ae03 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -231,6 +231,16 @@ namespace gtsam { const Names& names = {}) const override; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT); + ar& BOOST_SERIALIZATION_NVP(cardinalities_); + } }; // traits diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 6a286633d..b68953eb5 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional /// Internal version of choose DiscreteConditional::ADT choose(const DiscreteValues& given, bool forceComplete) const; + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + } }; // DiscreteConditional diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 63b14b05e..46e6c3813 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -20,12 +20,11 @@ // #define DT_DEBUG_MEMORY // #define GTSAM_DT_NO_PRUNING #define DISABLE_DOT -#include - -#include -#include - #include +#include +#include +#include +#include using namespace std; using namespace gtsam; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 1829db034..869b3c630 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 9098f7a1d..3df4bf9e6 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -17,13 +17,14 @@ * @date Feb 14, 2011 */ -#include - #include +#include #include #include #include +#include + using namespace std; using namespace gtsam; @@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) { DiscreteConditional conditional(A | B = "2/2 3/1"); DiscreteConditional prior(B % "1/2"); DiscreteConditional pAB = prior * conditional; - GTSAM_PRINT(pAB); // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8 // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 DiscreteConditional actualA = pAB.marginal(A.first); diff --git a/gtsam/discrete/tests/testSerializationDiscrete.cpp b/gtsam/discrete/tests/testSerializationDiscrete.cpp new file mode 100644 index 000000000..df7df0b7e --- /dev/null +++ b/gtsam/discrete/tests/testSerializationDiscrete.cpp @@ -0,0 +1,105 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testSerializtionDiscrete.cpp + * + * @date January 2023 + * @author Varun Agrawal + */ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +using Tree = gtsam::DecisionTree; + +BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt") +BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf") +BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice") + +BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); + +using ADT = AlgebraicDecisionTree; +BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); +BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf") +BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice") + +/* ****************************************************************************/ +// Test DecisionTree serialization. +TEST(DiscreteSerialization, DecisionTree) { + Tree tree({{"A", 2}}, std::vector{1, 2}); + + using namespace serializationTestHelpers; + + // Object roundtrip + Tree outputObj = create(); + roundtrip(tree, outputObj); + EXPECT(tree.equals(outputObj)); + + // XML roundtrip + Tree outputXml = create(); + roundtripXML(tree, outputXml); + EXPECT(tree.equals(outputXml)); + + // Binary roundtrip + Tree outputBinary = create(); + roundtripBinary(tree, outputBinary); + EXPECT(tree.equals(outputBinary)); +} + +/* ************************************************************************* */ +// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor +TEST(DiscreteSerialization, DecisionTreeFactor) { + using namespace serializationTestHelpers; + + DiscreteKey A(1, 2), B(2, 2), C(3, 2); + + DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8"); + EXPECT(equalsObj(tree)); + EXPECT(equalsXML(tree)); + EXPECT(equalsBinary(tree)); + + DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + EXPECT(equalsObj(f)); + EXPECT(equalsXML(f)); + EXPECT(equalsBinary(f)); +} + +/* ************************************************************************* */ +// Check serialization for DiscreteConditional & DiscreteDistribution +TEST(DiscreteSerialization, DiscreteConditional) { + using namespace serializationTestHelpers; + + DiscreteKey A(Symbol('x', 1), 3); + DiscreteConditional conditional(A % "1/2/2"); + + EXPECT(equalsObj(conditional)); + EXPECT(equalsXML(conditional)); + EXPECT(equalsBinary(conditional)); + + DiscreteDistribution P(A % "3/2/1"); + EXPECT(equalsObj(P)); + EXPECT(equalsXML(P)); + EXPECT(equalsBinary(P)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 240f79dcd..ba84b5ade 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -197,6 +197,16 @@ class GTSAM_EXPORT GaussianMixture */ GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + ar &BOOST_SERIALIZATION_NVP(conditionals_); + } }; /// Return the DiscreteKey vector as a set. diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 9138d6b30..01de2f0f7 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -70,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { bool operator==(const FactorAndConstant &other) const { return factor == other.factor && constant == other.constant; } + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_NVP(factor); + ar &BOOST_SERIALIZATION_NVP(constant); + } }; /// typedef for Decision Tree of Gaussian factors and log-constant. @@ -179,6 +188,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { return sum; } /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(factors_); + } }; // traits diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 8371a3365..0bfcfec4d 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -116,7 +116,6 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const { auto other = e->asDiscrete(); return other != nullptr && dc->equals(*other, tol); } - return inner_->equals(*(e->inner_), tol); return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false) : !(e->inner_); diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index da7c1421e..021ca1361 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -188,6 +188,19 @@ class GTSAM_EXPORT HybridConditional ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); ar& BOOST_SERIALIZATION_NVP(inner_); + + // register the various casts based on the type of inner_ + // https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting + if (isDiscrete()) { + boost::serialization::void_cast_register( + static_cast(NULL), static_cast(NULL)); + } else if (isContinuous()) { + boost::serialization::void_cast_register( + static_cast(NULL), static_cast(NULL)); + } else { + boost::serialization::void_cast_register( + static_cast(NULL), static_cast(NULL)); + } } }; // HybridConditional diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 605ea5738..dd8b867be 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -40,8 +40,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf) /* ************************************************************************ */ bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { const This *e = dynamic_cast(&lf); - // TODO(Varun) How to compare inner_ when they are abstract types? - return e != nullptr && Base::equals(*e, tol); + if (e == nullptr) return false; + if (!Base::equals(*e, tol)) return false; + return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false) + : !(e->inner_); } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h index 7ac97443a..7a43ab3a5 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { /// @name Constructors /// @{ + /// Default constructor - for serialization. + HybridDiscreteFactor() = default; + // Implicit conversion from a shared ptr of DF HybridDiscreteFactor(DiscreteFactor::shared_ptr other); @@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { /// Return the error of the underlying Discrete Factor. double error(const HybridValues &values) const override; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(inner_); + } }; // traits diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 5a89a04a8..4fe18bea7 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -44,15 +44,21 @@ HybridGaussianFactor::HybridGaussianFactor(HessianFactor &&hf) /* ************************************************************************* */ bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const { const This *e = dynamic_cast(&other); - // TODO(Varun) How to compare inner_ when they are abstract types? - return e != nullptr && Base::equals(*e, tol); + if (e == nullptr) return false; + if (!Base::equals(*e, tol)) return false; + return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false) + : !(e->inner_); } /* ************************************************************************* */ void HybridGaussianFactor::print(const std::string &s, const KeyFormatter &formatter) const { HybridFactor::print(s, formatter); - inner_->print("\n", formatter); + if (inner_) { + inner_->print("\n", formatter); + } else { + std::cout << "\nGaussian: nullptr" << std::endl; + } }; /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 6d179d1c1..6bb022396 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -43,6 +43,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { using This = HybridGaussianFactor; using shared_ptr = boost::shared_ptr; + /// @name Constructors + /// @{ + + /// Default constructor - for serialization. HybridGaussianFactor() = default; /** @@ -79,7 +83,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { */ explicit HybridGaussianFactor(HessianFactor &&hf); - public: + /// @} /// @name Testable /// @{ @@ -101,6 +105,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// Return the error of the underlying Gaussian factor. double error(const HybridValues &values) const override; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(inner_); + } }; // traits diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 43eebcf88..ef552bd92 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -18,7 +18,6 @@ * @date December 2021 */ -#include #include #include #include @@ -31,7 +30,6 @@ using namespace std; using namespace gtsam; -using namespace gtsam::serializationTestHelpers; using noiseModel::Isotropic; using symbol_shorthand::M; @@ -330,20 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { discrete_conditional_tree->apply(checker); } -/* ****************************************************************************/ -// Test HybridBayesNet serialization. -TEST(HybridBayesNet, Serialization) { - Switching s(4); - Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); - HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); - - // TODO(Varun) Serialization of inner factor doesn't work. Requires - // serialization support for all hybrid factors. - // EXPECT(equalsObj(hbn)); - // EXPECT(equalsXML(hbn)); - // EXPECT(equalsBinary(hbn)); -} - /* ****************************************************************************/ // Test HybridBayesNet sampling. TEST(HybridBayesNet, Sampling) { diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index 1e6510383..b957a67d0 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -220,22 +220,6 @@ TEST(HybridBayesTree, Choose) { EXPECT(assert_equal(expected_gbt, gbt)); } -/* ****************************************************************************/ -// Test HybridBayesTree serialization. -TEST(HybridBayesTree, Serialization) { - Switching s(4); - Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); - HybridBayesTree hbt = - *(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); - - using namespace gtsam::serializationTestHelpers; - // TODO(Varun) Serialization of inner factor doesn't work. Requires - // serialization support for all hybrid factors. - // EXPECT(equalsObj(hbt)); - // EXPECT(equalsXML(hbt)); - // EXPECT(equalsBinary(hbt)); -} - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testSerializationHybrid.cpp b/gtsam/hybrid/tests/testSerializationHybrid.cpp new file mode 100644 index 000000000..941a1cdb3 --- /dev/null +++ b/gtsam/hybrid/tests/testSerializationHybrid.cpp @@ -0,0 +1,179 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testSerializationHybrid.cpp + * @brief Unit tests for hybrid serialization + * @author Varun Agrawal + * @date January 2023 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Switching.h" + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using symbol_shorthand::M; +using symbol_shorthand::X; +using symbol_shorthand::Z; + +using namespace serializationTestHelpers; + +BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor"); +BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor"); +BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor"); +BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional"); +BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional"); + +BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); +using ADT = AlgebraicDecisionTree; +BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); +BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf"); +BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice") + +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor"); +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors, + "gtsam_GaussianMixtureFactor_Factors"); +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Leaf, + "gtsam_GaussianMixtureFactor_Factors_Leaf"); +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Choice, + "gtsam_GaussianMixtureFactor_Factors_Choice"); + +BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture"); +BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals, + "gtsam_GaussianMixture_Conditionals"); +BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf, + "gtsam_GaussianMixture_Conditionals_Leaf"); +BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice, + "gtsam_GaussianMixture_Conditionals_Choice"); +// Needed since GaussianConditional::FromMeanAndStddev uses it +BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); + +BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet"); + +/* ****************************************************************************/ +// Test HybridGaussianFactor serialization. +TEST(HybridSerialization, HybridGaussianFactor) { + const HybridGaussianFactor factor(JacobianFactor(X(0), I_3x3, Z_3x1)); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +/* ****************************************************************************/ +// Test HybridDiscreteFactor serialization. +TEST(HybridSerialization, HybridDiscreteFactor) { + DiscreteKeys discreteKeys{{M(0), 2}}; + const HybridDiscreteFactor factor( + DecisionTreeFactor(discreteKeys, std::vector{0.4, 0.6})); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +/* ****************************************************************************/ +// Test GaussianMixtureFactor serialization. +TEST(HybridSerialization, GaussianMixtureFactor) { + KeyVector continuousKeys{X(0)}; + DiscreteKeys discreteKeys{{M(0), 2}}; + + auto A = Matrix::Zero(2, 1); + auto b0 = Matrix::Zero(2, 1); + auto b1 = Matrix::Ones(2, 1); + auto f0 = boost::make_shared(X(0), A, b0); + auto f1 = boost::make_shared(X(0), A, b1); + std::vector factors{f0, f1}; + + const GaussianMixtureFactor factor(continuousKeys, discreteKeys, factors); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +/* ****************************************************************************/ +// Test HybridConditional serialization. +TEST(HybridSerialization, HybridConditional) { + const DiscreteKey mode(M(0), 2); + Matrix1 I = Matrix1::Identity(); + const auto conditional = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5)); + const HybridConditional hc(conditional); + + EXPECT(equalsObj(hc)); + EXPECT(equalsXML(hc)); + EXPECT(equalsBinary(hc)); +} + +/* ****************************************************************************/ +// Test GaussianMixture serialization. +TEST(HybridSerialization, GaussianMixture) { + const DiscreteKey mode(M(0), 2); + Matrix1 I = Matrix1::Identity(); + const auto conditional0 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5)); + const auto conditional1 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3)); + const GaussianMixture gm({Z(0)}, {X(0)}, {mode}, + {conditional0, conditional1}); + + EXPECT(equalsObj(gm)); + EXPECT(equalsXML(gm)); + EXPECT(equalsBinary(gm)); +} + +/* ****************************************************************************/ +// Test HybridBayesNet serialization. +TEST(HybridSerialization, HybridBayesNet) { + Switching s(2); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); + + EXPECT(equalsObj(hbn)); + EXPECT(equalsXML(hbn)); + EXPECT(equalsBinary(hbn)); +} + +/* ****************************************************************************/ +// Test HybridBayesTree serialization. +TEST(HybridSerialization, HybridBayesTree) { + Switching s(2); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesTree hbt = + *(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); + + EXPECT(equalsObj(hbt)); + EXPECT(equalsXML(hbt)); + EXPECT(equalsBinary(hbt)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */