Merge pull request #1368 from borglab/hybrid/serialization

Fixes https://github.com/borglab/gtsam/issues/1366
release/4.3a0
Varun Agrawal 2023-01-04 03:51:22 -05:00 committed by GitHub
commit b62f397085
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 441 additions and 47 deletions

View File

@ -64,6 +64,9 @@ namespace gtsam {
*/
size_t nrAssignments_;
/// Default constructor for serialization.
Leaf() {}
/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
@ -154,6 +157,18 @@ namespace gtsam {
}
bool isLeaf() const override { return true; }
private:
using Base = DecisionTree<L, Y>::Node;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
}
}; // Leaf
/****************************************************************************/
@ -177,6 +192,9 @@ namespace gtsam {
using ChoicePtr = boost::shared_ptr<const Choice>;
public:
/// Default constructor for serialization.
Choice() {}
~Choice() override {
#ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
@ -428,6 +446,19 @@ namespace gtsam {
r->push_back(branch->choose(label, index));
return Unique(r);
}
private:
using Base = DecisionTree<L, Y>::Node;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(label_);
ar& BOOST_SERIALIZATION_NVP(branches_);
ar& BOOST_SERIALIZATION_NVP(allSame_);
}
}; // Choice
/****************************************************************************/

View File

@ -19,9 +19,11 @@
#pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h>
#include <boost/serialization/nvp.hpp>
#include <boost/shared_ptr.hpp>
#include <functional>
#include <iostream>
@ -113,6 +115,12 @@ namespace gtsam {
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
virtual Ptr choose(const L& label, size_t index) const = 0;
virtual bool isLeaf() const = 0;
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
};
/** ------------------------ Node base class --------------------------- */
@ -364,8 +372,19 @@ namespace gtsam {
compose(Iterator begin, Iterator end, const L& label) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_NVP(root_);
}
}; // DecisionTree
template <class L, class Y>
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
/** free versions of apply */
/// Apply unary operator `op` to DecisionTree `f`.

View File

@ -231,6 +231,16 @@ namespace gtsam {
const Names& names = {}) const override;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
};
// traits

View File

@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}
};
// DiscreteConditional

View File

@ -20,12 +20,11 @@
// #define DT_DEBUG_MEMORY
// #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h>
using namespace std;
using namespace gtsam;

View File

@ -19,6 +19,7 @@
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>

View File

@ -17,13 +17,14 @@
* @date Feb 14, 2011
*/
#include <boost/make_shared.hpp>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h>
#include <boost/make_shared.hpp>
using namespace std;
using namespace gtsam;
@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) {
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);

View File

@ -0,0 +1,105 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testSerializtionDiscrete.cpp
*
* @date January 2023
* @author Varun Agrawal
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/Symbol.h>
using namespace std;
using namespace gtsam;
using Tree = gtsam::DecisionTree<string, int>;
BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf")
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
/* ****************************************************************************/
// Test DecisionTree serialization.
TEST(DiscreteSerialization, DecisionTree) {
Tree tree({{"A", 2}}, std::vector<int>{1, 2});
using namespace serializationTestHelpers;
// Object roundtrip
Tree outputObj = create<Tree>();
roundtrip<Tree>(tree, outputObj);
EXPECT(tree.equals(outputObj));
// XML roundtrip
Tree outputXml = create<Tree>();
roundtripXML<Tree>(tree, outputXml);
EXPECT(tree.equals(outputXml));
// Binary roundtrip
Tree outputBinary = create<Tree>();
roundtripBinary<Tree>(tree, outputBinary);
EXPECT(tree.equals(outputBinary));
}
/* ************************************************************************* */
// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
TEST(DiscreteSerialization, DecisionTreeFactor) {
using namespace serializationTestHelpers;
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor>(f));
EXPECT(equalsXML<DecisionTreeFactor>(f));
EXPECT(equalsBinary<DecisionTreeFactor>(f));
}
/* ************************************************************************* */
// Check serialization for DiscreteConditional & DiscreteDistribution
TEST(DiscreteSerialization, DiscreteConditional) {
using namespace serializationTestHelpers;
DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");
EXPECT(equalsObj<DiscreteConditional>(conditional));
EXPECT(equalsXML<DiscreteConditional>(conditional));
EXPECT(equalsBinary<DiscreteConditional>(conditional));
DiscreteDistribution P(A % "3/2/1");
EXPECT(equalsObj<DiscreteDistribution>(P));
EXPECT(equalsXML<DiscreteDistribution>(P));
EXPECT(equalsBinary<DiscreteDistribution>(P));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -197,6 +197,16 @@ class GTSAM_EXPORT GaussianMixture
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar &BOOST_SERIALIZATION_NVP(conditionals_);
}
};
/// Return the DiscreteKey vector as a set.

View File

@ -70,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_NVP(factor);
ar &BOOST_SERIALIZATION_NVP(constant);
}
};
/// typedef for Decision Tree of Gaussian factors and log-constant.
@ -179,6 +188,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
return sum;
}
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(factors_);
}
};
// traits

View File

@ -116,7 +116,6 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
auto other = e->asDiscrete();
return other != nullptr && dc->equals(*other, tol);
}
return inner_->equals(*(e->inner_), tol);
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);

View File

@ -188,6 +188,19 @@ class GTSAM_EXPORT HybridConditional
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(inner_);
// register the various casts based on the type of inner_
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
if (isDiscrete()) {
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
} else if (isContinuous()) {
boost::serialization::void_cast_register<GaussianConditional, Factor>(
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
} else {
boost::serialization::void_cast_register<GaussianMixture, Factor>(
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
}
}
}; // HybridConditional

View File

@ -40,8 +40,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
/* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
if (e == nullptr) return false;
if (!Base::equals(*e, tol)) return false;
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}
/* ************************************************************************ */

View File

@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
/// @name Constructors
/// @{
/// Default constructor - for serialization.
HybridDiscreteFactor() = default;
// Implicit conversion from a shared ptr of DF
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
/// Return the error of the underlying Discrete Factor.
double error(const HybridValues &values) const override;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(inner_);
}
};
// traits

View File

@ -44,15 +44,21 @@ HybridGaussianFactor::HybridGaussianFactor(HessianFactor &&hf)
/* ************************************************************************* */
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
if (e == nullptr) return false;
if (!Base::equals(*e, tol)) return false;
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}
/* ************************************************************************* */
void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
inner_->print("\n", formatter);
if (inner_) {
inner_->print("\n", formatter);
} else {
std::cout << "\nGaussian: nullptr" << std::endl;
}
};
/* ************************************************************************ */

View File

@ -43,6 +43,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
using This = HybridGaussianFactor;
using shared_ptr = boost::shared_ptr<This>;
/// @name Constructors
/// @{
/// Default constructor - for serialization.
HybridGaussianFactor() = default;
/**
@ -79,7 +83,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/
explicit HybridGaussianFactor(HessianFactor &&hf);
public:
/// @}
/// @name Testable
/// @{
@ -101,6 +105,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// Return the error of the underlying Gaussian factor.
double error(const HybridValues &values) const override;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(inner_);
}
};
// traits

View File

@ -18,7 +18,6 @@
* @date December 2021
*/
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
@ -31,7 +30,6 @@
using namespace std;
using namespace gtsam;
using namespace gtsam::serializationTestHelpers;
using noiseModel::Isotropic;
using symbol_shorthand::M;
@ -330,20 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
discrete_conditional_tree->apply(checker);
}
/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
// TODO(Varun) Serialization of inner factor doesn't work. Requires
// serialization support for all hybrid factors.
// EXPECT(equalsObj<HybridBayesNet>(hbn));
// EXPECT(equalsXML<HybridBayesNet>(hbn));
// EXPECT(equalsBinary<HybridBayesNet>(hbn));
}
/* ****************************************************************************/
// Test HybridBayesNet sampling.
TEST(HybridBayesNet, Sampling) {

View File

@ -220,22 +220,6 @@ TEST(HybridBayesTree, Choose) {
EXPECT(assert_equal(expected_gbt, gbt));
}
/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
using namespace gtsam::serializationTestHelpers;
// TODO(Varun) Serialization of inner factor doesn't work. Requires
// serialization support for all hybrid factors.
// EXPECT(equalsObj<HybridBayesTree>(hbt));
// EXPECT(equalsXML<HybridBayesTree>(hbt));
// EXPECT(equalsBinary<HybridBayesTree>(hbt));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -0,0 +1,179 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testSerializationHybrid.cpp
* @brief Unit tests for hybrid serialization
* @author Varun Agrawal
* @date January 2023
*/
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include "Switching.h"
// Include for test suite
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;
using namespace serializationTestHelpers;
BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor");
BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors,
"gtsam_GaussianMixtureFactor_Factors");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Leaf,
"gtsam_GaussianMixtureFactor_Factors_Leaf");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Choice,
"gtsam_GaussianMixtureFactor_Factors_Choice");
BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals,
"gtsam_GaussianMixture_Conditionals");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf,
"gtsam_GaussianMixture_Conditionals_Leaf");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice,
"gtsam_GaussianMixture_Conditionals_Choice");
// Needed since GaussianConditional::FromMeanAndStddev uses it
BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic");
BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet");
/* ****************************************************************************/
// Test HybridGaussianFactor serialization.
TEST(HybridSerialization, HybridGaussianFactor) {
const HybridGaussianFactor factor(JacobianFactor(X(0), I_3x3, Z_3x1));
EXPECT(equalsObj<HybridGaussianFactor>(factor));
EXPECT(equalsXML<HybridGaussianFactor>(factor));
EXPECT(equalsBinary<HybridGaussianFactor>(factor));
}
/* ****************************************************************************/
// Test HybridDiscreteFactor serialization.
TEST(HybridSerialization, HybridDiscreteFactor) {
DiscreteKeys discreteKeys{{M(0), 2}};
const HybridDiscreteFactor factor(
DecisionTreeFactor(discreteKeys, std::vector<double>{0.4, 0.6}));
EXPECT(equalsObj<HybridDiscreteFactor>(factor));
EXPECT(equalsXML<HybridDiscreteFactor>(factor));
EXPECT(equalsBinary<HybridDiscreteFactor>(factor));
}
/* ****************************************************************************/
// Test GaussianMixtureFactor serialization.
TEST(HybridSerialization, GaussianMixtureFactor) {
KeyVector continuousKeys{X(0)};
DiscreteKeys discreteKeys{{M(0), 2}};
auto A = Matrix::Zero(2, 1);
auto b0 = Matrix::Zero(2, 1);
auto b1 = Matrix::Ones(2, 1);
auto f0 = boost::make_shared<JacobianFactor>(X(0), A, b0);
auto f1 = boost::make_shared<JacobianFactor>(X(0), A, b1);
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
const GaussianMixtureFactor factor(continuousKeys, discreteKeys, factors);
EXPECT(equalsObj<GaussianMixtureFactor>(factor));
EXPECT(equalsXML<GaussianMixtureFactor>(factor));
EXPECT(equalsBinary<GaussianMixtureFactor>(factor));
}
/* ****************************************************************************/
// Test HybridConditional serialization.
TEST(HybridSerialization, HybridConditional) {
const DiscreteKey mode(M(0), 2);
Matrix1 I = Matrix1::Identity();
const auto conditional = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
const HybridConditional hc(conditional);
EXPECT(equalsObj<HybridConditional>(hc));
EXPECT(equalsXML<HybridConditional>(hc));
EXPECT(equalsBinary<HybridConditional>(hc));
}
/* ****************************************************************************/
// Test GaussianMixture serialization.
TEST(HybridSerialization, GaussianMixture) {
const DiscreteKey mode(M(0), 2);
Matrix1 I = Matrix1::Identity();
const auto conditional0 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
const auto conditional1 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3));
const GaussianMixture gm({Z(0)}, {X(0)}, {mode},
{conditional0, conditional1});
EXPECT(equalsObj<GaussianMixture>(gm));
EXPECT(equalsXML<GaussianMixture>(gm));
EXPECT(equalsBinary<GaussianMixture>(gm));
}
/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridSerialization, HybridBayesNet) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}
/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridSerialization, HybridBayesTree) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
EXPECT(equalsBinary<HybridBayesTree>(hbt));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */