/* ---------------------------------------------------------------------------- * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) * See LICENSE for the license information * -------------------------------------------------------------------------- */ /* * testSerializtionDiscrete.cpp * * @date January 2023 * @author Varun Agrawal */ #include #include #include #include #include #include using namespace std; using namespace gtsam; using Tree = gtsam::DecisionTree; BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt") BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf") BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice") BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); BOOST_CLASS_EXPORT_GUID(TableFactor, "gtsam_TableFactor"); using ADT = AlgebraicDecisionTree; BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf") BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice") /* ****************************************************************************/ // Test DecisionTree serialization. TEST(DiscreteSerialization, DecisionTree) { Tree tree({{"A", 2}}, std::vector{1, 2}); using namespace serializationTestHelpers; // Object roundtrip Tree outputObj = create(); roundtrip(tree, outputObj); EXPECT(tree.equals(outputObj)); // XML roundtrip Tree outputXml = create(); roundtripXML(tree, outputXml); EXPECT(tree.equals(outputXml)); // Binary roundtrip Tree outputBinary = create(); roundtripBinary(tree, outputBinary); EXPECT(tree.equals(outputBinary)); } /* ************************************************************************* */ // Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor TEST(DiscreteSerialization, DecisionTreeFactor) { using namespace serializationTestHelpers; DiscreteKey A(1, 2), B(2, 2), C(3, 2); DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8"); EXPECT(equalsObj(tree)); EXPECT(equalsXML(tree)); EXPECT(equalsBinary(tree)); DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); EXPECT(equalsObj(f)); EXPECT(equalsXML(f)); EXPECT(equalsBinary(f)); } /* ************************************************************************* */ // Check serialization for TableFactor TEST(DiscreteSerialization, TableFactor) { using namespace serializationTestHelpers; DiscreteKey A(Symbol('x', 1), 3); TableFactor tf(A, "1 2 2"); EXPECT(equalsObj(tf)); EXPECT(equalsXML(tf)); EXPECT(equalsBinary(tf)); } /* ************************************************************************* */ // Check serialization for DiscreteConditional & DiscreteDistribution TEST(DiscreteSerialization, DiscreteConditional) { using namespace serializationTestHelpers; DiscreteKey A(Symbol('x', 1), 3); DiscreteConditional conditional(A % "1/2/2"); EXPECT(equalsObj(conditional)); EXPECT(equalsXML(conditional)); EXPECT(equalsBinary(conditional)); DiscreteDistribution P(A % "3/2/1"); EXPECT(equalsObj(P)); EXPECT(equalsXML(P)); EXPECT(equalsBinary(P)); } /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr); } /* ************************************************************************* */