diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index da7c1421e..021ca1361 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -188,6 +188,19 @@ class GTSAM_EXPORT HybridConditional ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); ar& BOOST_SERIALIZATION_NVP(inner_); + + // register the various casts based on the type of inner_ + // https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting + if (isDiscrete()) { + boost::serialization::void_cast_register( + static_cast(NULL), static_cast(NULL)); + } else if (isContinuous()) { + boost::serialization::void_cast_register( + static_cast(NULL), static_cast(NULL)); + } else { + boost::serialization::void_cast_register( + static_cast(NULL), static_cast(NULL)); + } } }; // HybridConditional diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 43eebcf88..ef552bd92 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -18,7 +18,6 @@ * @date December 2021 */ -#include #include #include #include @@ -31,7 +30,6 @@ using namespace std; using namespace gtsam; -using namespace gtsam::serializationTestHelpers; using noiseModel::Isotropic; using symbol_shorthand::M; @@ -330,20 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { discrete_conditional_tree->apply(checker); } -/* ****************************************************************************/ -// Test HybridBayesNet serialization. -TEST(HybridBayesNet, Serialization) { - Switching s(4); - Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); - HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); - - // TODO(Varun) Serialization of inner factor doesn't work. Requires - // serialization support for all hybrid factors. - // EXPECT(equalsObj(hbn)); - // EXPECT(equalsXML(hbn)); - // EXPECT(equalsBinary(hbn)); -} - /* ****************************************************************************/ // Test HybridBayesNet sampling. TEST(HybridBayesNet, Sampling) { diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index 1e6510383..b957a67d0 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -220,22 +220,6 @@ TEST(HybridBayesTree, Choose) { EXPECT(assert_equal(expected_gbt, gbt)); } -/* ****************************************************************************/ -// Test HybridBayesTree serialization. -TEST(HybridBayesTree, Serialization) { - Switching s(4); - Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); - HybridBayesTree hbt = - *(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); - - using namespace gtsam::serializationTestHelpers; - // TODO(Varun) Serialization of inner factor doesn't work. Requires - // serialization support for all hybrid factors. - // EXPECT(equalsObj(hbt)); - // EXPECT(equalsXML(hbt)); - // EXPECT(equalsBinary(hbt)); -} - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testSerializationHybrid.cpp b/gtsam/hybrid/tests/testSerializationHybrid.cpp index 9597fe8f0..941a1cdb3 100644 --- a/gtsam/hybrid/tests/testSerializationHybrid.cpp +++ b/gtsam/hybrid/tests/testSerializationHybrid.cpp @@ -17,14 +17,19 @@ */ #include +#include #include #include +#include +#include #include #include #include #include #include +#include "Switching.h" + // Include for test suite #include @@ -36,15 +41,17 @@ using symbol_shorthand::Z; using namespace serializationTestHelpers; +BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor"); BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor"); BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor"); BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional"); +BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional"); BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); using ADT = AlgebraicDecisionTree; BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); -BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_DecisionTree_Leaf"); -BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_DecisionTree_Choice"); +BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf"); +BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice") BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor"); BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors, @@ -64,6 +71,8 @@ BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice, // Needed since GaussianConditional::FromMeanAndStddev uses it BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); +BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet"); + /* ****************************************************************************/ // Test HybridGaussianFactor serialization. TEST(HybridSerialization, HybridGaussianFactor) { @@ -137,6 +146,31 @@ TEST(HybridSerialization, GaussianMixture) { EXPECT(equalsBinary(gm)); } +/* ****************************************************************************/ +// Test HybridBayesNet serialization. +TEST(HybridSerialization, HybridBayesNet) { + Switching s(2); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); + + EXPECT(equalsObj(hbn)); + EXPECT(equalsXML(hbn)); + EXPECT(equalsBinary(hbn)); +} + +/* ****************************************************************************/ +// Test HybridBayesTree serialization. +TEST(HybridSerialization, HybridBayesTree) { + Switching s(2); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesTree hbt = + *(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); + + EXPECT(equalsObj(hbt)); + EXPECT(equalsXML(hbt)); + EXPECT(equalsBinary(hbt)); +} + /* ************************************************************************* */ int main() { TestResult tr;