Merge pull request #1368 from borglab/hybrid/serialization
Fixes https://github.com/borglab/gtsam/issues/1366release/4.3a0
commit
b62f397085
|
|
@ -64,6 +64,9 @@ namespace gtsam {
|
|||
*/
|
||||
size_t nrAssignments_;
|
||||
|
||||
/// Default constructor for serialization.
|
||||
Leaf() {}
|
||||
|
||||
/// Constructor from constant
|
||||
Leaf(const Y& constant, size_t nrAssignments = 1)
|
||||
: constant_(constant), nrAssignments_(nrAssignments) {}
|
||||
|
|
@ -154,6 +157,18 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
bool isLeaf() const override { return true; }
|
||||
|
||||
private:
|
||||
using Base = DecisionTree<L, Y>::Node;
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_NVP(constant_);
|
||||
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
|
||||
}
|
||||
}; // Leaf
|
||||
|
||||
/****************************************************************************/
|
||||
|
|
@ -177,6 +192,9 @@ namespace gtsam {
|
|||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||
|
||||
public:
|
||||
/// Default constructor for serialization.
|
||||
Choice() {}
|
||||
|
||||
~Choice() override {
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||
|
|
@ -428,6 +446,19 @@ namespace gtsam {
|
|||
r->push_back(branch->choose(label, index));
|
||||
return Unique(r);
|
||||
}
|
||||
|
||||
private:
|
||||
using Base = DecisionTree<L, Y>::Node;
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_NVP(label_);
|
||||
ar& BOOST_SERIALIZATION_NVP(branches_);
|
||||
ar& BOOST_SERIALIZATION_NVP(allSame_);
|
||||
}
|
||||
}; // Choice
|
||||
|
||||
/****************************************************************************/
|
||||
|
|
|
|||
|
|
@ -19,9 +19,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/types.h>
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
|
@ -113,6 +115,12 @@ namespace gtsam {
|
|||
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
||||
virtual Ptr choose(const L& label, size_t index) const = 0;
|
||||
virtual bool isLeaf() const = 0;
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
|
||||
};
|
||||
/** ------------------------ Node base class --------------------------- */
|
||||
|
||||
|
|
@ -364,8 +372,19 @@ namespace gtsam {
|
|||
compose(Iterator begin, Iterator end, const L& label) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_NVP(root_);
|
||||
}
|
||||
}; // DecisionTree
|
||||
|
||||
template <class L, class Y>
|
||||
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
|
||||
|
||||
/** free versions of apply */
|
||||
|
||||
/// Apply unary operator `op` to DecisionTree `f`.
|
||||
|
|
|
|||
|
|
@ -231,6 +231,16 @@ namespace gtsam {
|
|||
const Names& names = {}) const override;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
|
||||
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/// Internal version of choose
|
||||
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||
bool forceComplete) const;
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
}
|
||||
};
|
||||
// DiscreteConditional
|
||||
|
||||
|
|
|
|||
|
|
@ -20,12 +20,11 @@
|
|||
// #define DT_DEBUG_MEMORY
|
||||
// #define GTSAM_DT_NO_PRUNING
|
||||
#define DISABLE_DOT
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
|
|
|||
|
|
@ -17,13 +17,14 @@
|
|||
* @date Feb 14, 2011
|
||||
*/
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
|
|
@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) {
|
|||
DiscreteConditional conditional(A | B = "2/2 3/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
DiscreteConditional pAB = prior * conditional;
|
||||
GTSAM_PRINT(pAB);
|
||||
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
|
||||
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,105 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* testSerializtionDiscrete.cpp
|
||||
*
|
||||
* @date January 2023
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
using Tree = gtsam::DecisionTree<string, int>;
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
|
||||
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
|
||||
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
|
||||
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
|
||||
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf")
|
||||
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test DecisionTree serialization.
|
||||
TEST(DiscreteSerialization, DecisionTree) {
|
||||
Tree tree({{"A", 2}}, std::vector<int>{1, 2});
|
||||
|
||||
using namespace serializationTestHelpers;
|
||||
|
||||
// Object roundtrip
|
||||
Tree outputObj = create<Tree>();
|
||||
roundtrip<Tree>(tree, outputObj);
|
||||
EXPECT(tree.equals(outputObj));
|
||||
|
||||
// XML roundtrip
|
||||
Tree outputXml = create<Tree>();
|
||||
roundtripXML<Tree>(tree, outputXml);
|
||||
EXPECT(tree.equals(outputXml));
|
||||
|
||||
// Binary roundtrip
|
||||
Tree outputBinary = create<Tree>();
|
||||
roundtripBinary<Tree>(tree, outputBinary);
|
||||
EXPECT(tree.equals(outputBinary));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
|
||||
TEST(DiscreteSerialization, DecisionTreeFactor) {
|
||||
using namespace serializationTestHelpers;
|
||||
|
||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||
|
||||
DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
|
||||
EXPECT(equalsObj<DecisionTreeFactor::ADT>(tree));
|
||||
EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
|
||||
EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));
|
||||
|
||||
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
||||
EXPECT(equalsObj<DecisionTreeFactor>(f));
|
||||
EXPECT(equalsXML<DecisionTreeFactor>(f));
|
||||
EXPECT(equalsBinary<DecisionTreeFactor>(f));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check serialization for DiscreteConditional & DiscreteDistribution
|
||||
TEST(DiscreteSerialization, DiscreteConditional) {
|
||||
using namespace serializationTestHelpers;
|
||||
|
||||
DiscreteKey A(Symbol('x', 1), 3);
|
||||
DiscreteConditional conditional(A % "1/2/2");
|
||||
|
||||
EXPECT(equalsObj<DiscreteConditional>(conditional));
|
||||
EXPECT(equalsXML<DiscreteConditional>(conditional));
|
||||
EXPECT(equalsBinary<DiscreteConditional>(conditional));
|
||||
|
||||
DiscreteDistribution P(A % "3/2/1");
|
||||
EXPECT(equalsObj<DiscreteDistribution>(P));
|
||||
EXPECT(equalsXML<DiscreteDistribution>(P));
|
||||
EXPECT(equalsBinary<DiscreteDistribution>(P));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -197,6 +197,16 @@ class GTSAM_EXPORT GaussianMixture
|
|||
*/
|
||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class Archive>
|
||||
void serialize(Archive &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
ar &BOOST_SERIALIZATION_NVP(conditionals_);
|
||||
}
|
||||
};
|
||||
|
||||
/// Return the DiscreteKey vector as a set.
|
||||
|
|
|
|||
|
|
@ -70,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
bool operator==(const FactorAndConstant &other) const {
|
||||
return factor == other.factor && constant == other.constant;
|
||||
}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_NVP(factor);
|
||||
ar &BOOST_SERIALIZATION_NVP(constant);
|
||||
}
|
||||
};
|
||||
|
||||
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||
|
|
@ -179,6 +188,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
return sum;
|
||||
}
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar &BOOST_SERIALIZATION_NVP(factors_);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -116,7 +116,6 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
|||
auto other = e->asDiscrete();
|
||||
return other != nullptr && dc->equals(*other, tol);
|
||||
}
|
||||
return inner_->equals(*(e->inner_), tol);
|
||||
|
||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||
: !(e->inner_);
|
||||
|
|
|
|||
|
|
@ -188,6 +188,19 @@ class GTSAM_EXPORT HybridConditional
|
|||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
ar& BOOST_SERIALIZATION_NVP(inner_);
|
||||
|
||||
// register the various casts based on the type of inner_
|
||||
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
|
||||
if (isDiscrete()) {
|
||||
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
|
||||
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
|
||||
} else if (isContinuous()) {
|
||||
boost::serialization::void_cast_register<GaussianConditional, Factor>(
|
||||
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
|
||||
} else {
|
||||
boost::serialization::void_cast_register<GaussianMixture, Factor>(
|
||||
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
|
||||
}
|
||||
}
|
||||
|
||||
}; // HybridConditional
|
||||
|
|
|
|||
|
|
@ -40,8 +40,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
|
|||
/* ************************************************************************ */
|
||||
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
|
||||
const This *e = dynamic_cast<const This *>(&lf);
|
||||
// TODO(Varun) How to compare inner_ when they are abstract types?
|
||||
return e != nullptr && Base::equals(*e, tol);
|
||||
if (e == nullptr) return false;
|
||||
if (!Base::equals(*e, tol)) return false;
|
||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||
: !(e->inner_);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
|||
|
|
@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
|||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor - for serialization.
|
||||
HybridDiscreteFactor() = default;
|
||||
|
||||
// Implicit conversion from a shared ptr of DF
|
||||
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
|
||||
|
||||
|
|
@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
|||
/// Return the error of the underlying Discrete Factor.
|
||||
double error(const HybridValues &values) const override;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar &BOOST_SERIALIZATION_NVP(inner_);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -44,15 +44,21 @@ HybridGaussianFactor::HybridGaussianFactor(HessianFactor &&hf)
|
|||
/* ************************************************************************* */
|
||||
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
|
||||
const This *e = dynamic_cast<const This *>(&other);
|
||||
// TODO(Varun) How to compare inner_ when they are abstract types?
|
||||
return e != nullptr && Base::equals(*e, tol);
|
||||
if (e == nullptr) return false;
|
||||
if (!Base::equals(*e, tol)) return false;
|
||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||
: !(e->inner_);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridGaussianFactor::print(const std::string &s,
|
||||
const KeyFormatter &formatter) const {
|
||||
HybridFactor::print(s, formatter);
|
||||
inner_->print("\n", formatter);
|
||||
if (inner_) {
|
||||
inner_->print("\n", formatter);
|
||||
} else {
|
||||
std::cout << "\nGaussian: nullptr" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
using This = HybridGaussianFactor;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor - for serialization.
|
||||
HybridGaussianFactor() = default;
|
||||
|
||||
/**
|
||||
|
|
@ -79,7 +83,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
*/
|
||||
explicit HybridGaussianFactor(HessianFactor &&hf);
|
||||
|
||||
public:
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
|
|
@ -101,6 +105,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
/// Return the error of the underlying Gaussian factor.
|
||||
double error(const HybridValues &values) const override;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar &BOOST_SERIALIZATION_NVP(inner_);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@
|
|||
* @date December 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
|
|
@ -31,7 +30,6 @@
|
|||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::M;
|
||||
|
|
@ -330,20 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
discrete_conditional_tree->apply(checker);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridBayesNet, Serialization) {
|
||||
Switching s(4);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||
|
||||
// TODO(Varun) Serialization of inner factor doesn't work. Requires
|
||||
// serialization support for all hybrid factors.
|
||||
// EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||
// EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||
// EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet sampling.
|
||||
TEST(HybridBayesNet, Sampling) {
|
||||
|
|
|
|||
|
|
@ -220,22 +220,6 @@ TEST(HybridBayesTree, Choose) {
|
|||
EXPECT(assert_equal(expected_gbt, gbt));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesTree serialization.
|
||||
TEST(HybridBayesTree, Serialization) {
|
||||
Switching s(4);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesTree hbt =
|
||||
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
||||
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
// TODO(Varun) Serialization of inner factor doesn't work. Requires
|
||||
// serialization support for all hybrid factors.
|
||||
// EXPECT(equalsObj<HybridBayesTree>(hbt));
|
||||
// EXPECT(equalsXML<HybridBayesTree>(hbt));
|
||||
// EXPECT(equalsBinary<HybridBayesTree>(hbt));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,179 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testSerializationHybrid.cpp
|
||||
* @brief Unit tests for hybrid serialization
|
||||
* @author Varun Agrawal
|
||||
* @date January 2023
|
||||
*/
|
||||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
|
||||
#include "Switching.h"
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
using symbol_shorthand::Z;
|
||||
|
||||
using namespace serializationTestHelpers;
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor");
|
||||
BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
|
||||
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
|
||||
BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
|
||||
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
|
||||
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors,
|
||||
"gtsam_GaussianMixtureFactor_Factors");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Leaf,
|
||||
"gtsam_GaussianMixtureFactor_Factors_Leaf");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Choice,
|
||||
"gtsam_GaussianMixtureFactor_Factors_Choice");
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals,
|
||||
"gtsam_GaussianMixture_Conditionals");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf,
|
||||
"gtsam_GaussianMixture_Conditionals_Leaf");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice,
|
||||
"gtsam_GaussianMixture_Conditionals_Choice");
|
||||
// Needed since GaussianConditional::FromMeanAndStddev uses it
|
||||
BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic");
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet");
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridGaussianFactor serialization.
|
||||
TEST(HybridSerialization, HybridGaussianFactor) {
|
||||
const HybridGaussianFactor factor(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||
|
||||
EXPECT(equalsObj<HybridGaussianFactor>(factor));
|
||||
EXPECT(equalsXML<HybridGaussianFactor>(factor));
|
||||
EXPECT(equalsBinary<HybridGaussianFactor>(factor));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridDiscreteFactor serialization.
|
||||
TEST(HybridSerialization, HybridDiscreteFactor) {
|
||||
DiscreteKeys discreteKeys{{M(0), 2}};
|
||||
const HybridDiscreteFactor factor(
|
||||
DecisionTreeFactor(discreteKeys, std::vector<double>{0.4, 0.6}));
|
||||
|
||||
EXPECT(equalsObj<HybridDiscreteFactor>(factor));
|
||||
EXPECT(equalsXML<HybridDiscreteFactor>(factor));
|
||||
EXPECT(equalsBinary<HybridDiscreteFactor>(factor));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test GaussianMixtureFactor serialization.
|
||||
TEST(HybridSerialization, GaussianMixtureFactor) {
|
||||
KeyVector continuousKeys{X(0)};
|
||||
DiscreteKeys discreteKeys{{M(0), 2}};
|
||||
|
||||
auto A = Matrix::Zero(2, 1);
|
||||
auto b0 = Matrix::Zero(2, 1);
|
||||
auto b1 = Matrix::Ones(2, 1);
|
||||
auto f0 = boost::make_shared<JacobianFactor>(X(0), A, b0);
|
||||
auto f1 = boost::make_shared<JacobianFactor>(X(0), A, b1);
|
||||
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
|
||||
|
||||
const GaussianMixtureFactor factor(continuousKeys, discreteKeys, factors);
|
||||
|
||||
EXPECT(equalsObj<GaussianMixtureFactor>(factor));
|
||||
EXPECT(equalsXML<GaussianMixtureFactor>(factor));
|
||||
EXPECT(equalsBinary<GaussianMixtureFactor>(factor));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridConditional serialization.
|
||||
TEST(HybridSerialization, HybridConditional) {
|
||||
const DiscreteKey mode(M(0), 2);
|
||||
Matrix1 I = Matrix1::Identity();
|
||||
const auto conditional = boost::make_shared<GaussianConditional>(
|
||||
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
|
||||
const HybridConditional hc(conditional);
|
||||
|
||||
EXPECT(equalsObj<HybridConditional>(hc));
|
||||
EXPECT(equalsXML<HybridConditional>(hc));
|
||||
EXPECT(equalsBinary<HybridConditional>(hc));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test GaussianMixture serialization.
|
||||
TEST(HybridSerialization, GaussianMixture) {
|
||||
const DiscreteKey mode(M(0), 2);
|
||||
Matrix1 I = Matrix1::Identity();
|
||||
const auto conditional0 = boost::make_shared<GaussianConditional>(
|
||||
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
|
||||
const auto conditional1 = boost::make_shared<GaussianConditional>(
|
||||
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3));
|
||||
const GaussianMixture gm({Z(0)}, {X(0)}, {mode},
|
||||
{conditional0, conditional1});
|
||||
|
||||
EXPECT(equalsObj<GaussianMixture>(gm));
|
||||
EXPECT(equalsXML<GaussianMixture>(gm));
|
||||
EXPECT(equalsBinary<GaussianMixture>(gm));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridSerialization, HybridBayesNet) {
|
||||
Switching s(2);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||
|
||||
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesTree serialization.
|
||||
TEST(HybridSerialization, HybridBayesTree) {
|
||||
Switching s(2);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesTree hbt =
|
||||
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
||||
|
||||
EXPECT(equalsObj<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsXML<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsBinary<HybridBayesTree>(hbt));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
Loading…
Reference in New Issue