Merge pull request #1368 from borglab/hybrid/serialization
Fixes https://github.com/borglab/gtsam/issues/1366release/4.3a0
commit
b62f397085
|
|
@ -64,6 +64,9 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
size_t nrAssignments_;
|
size_t nrAssignments_;
|
||||||
|
|
||||||
|
/// Default constructor for serialization.
|
||||||
|
Leaf() {}
|
||||||
|
|
||||||
/// Constructor from constant
|
/// Constructor from constant
|
||||||
Leaf(const Y& constant, size_t nrAssignments = 1)
|
Leaf(const Y& constant, size_t nrAssignments = 1)
|
||||||
: constant_(constant), nrAssignments_(nrAssignments) {}
|
: constant_(constant), nrAssignments_(nrAssignments) {}
|
||||||
|
|
@ -154,6 +157,18 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLeaf() const override { return true; }
|
bool isLeaf() const override { return true; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
using Base = DecisionTree<L, Y>::Node;
|
||||||
|
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(constant_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
|
||||||
|
}
|
||||||
}; // Leaf
|
}; // Leaf
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
|
@ -177,6 +192,9 @@ namespace gtsam {
|
||||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
/// Default constructor for serialization.
|
||||||
|
Choice() {}
|
||||||
|
|
||||||
~Choice() override {
|
~Choice() override {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||||
|
|
@ -428,6 +446,19 @@ namespace gtsam {
|
||||||
r->push_back(branch->choose(label, index));
|
r->push_back(branch->choose(label, index));
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
using Base = DecisionTree<L, Y>::Node;
|
||||||
|
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(label_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(branches_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(allSame_);
|
||||||
|
}
|
||||||
}; // Choice
|
}; // Choice
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,11 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/types.h>
|
#include <gtsam/base/types.h>
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
|
|
||||||
|
#include <boost/serialization/nvp.hpp>
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
@ -113,6 +115,12 @@ namespace gtsam {
|
||||||
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
||||||
virtual Ptr choose(const L& label, size_t index) const = 0;
|
virtual Ptr choose(const L& label, size_t index) const = 0;
|
||||||
virtual bool isLeaf() const = 0;
|
virtual bool isLeaf() const = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
|
||||||
};
|
};
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
|
|
||||||
|
|
@ -364,8 +372,19 @@ namespace gtsam {
|
||||||
compose(Iterator begin, Iterator end, const L& label) const;
|
compose(Iterator begin, Iterator end, const L& label) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(root_);
|
||||||
|
}
|
||||||
}; // DecisionTree
|
}; // DecisionTree
|
||||||
|
|
||||||
|
template <class L, class Y>
|
||||||
|
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
|
||||||
|
|
||||||
/** free versions of apply */
|
/** free versions of apply */
|
||||||
|
|
||||||
/// Apply unary operator `op` to DecisionTree `f`.
|
/// Apply unary operator `op` to DecisionTree `f`.
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,16 @@ namespace gtsam {
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
|
|
@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// Internal version of choose
|
/// Internal version of choose
|
||||||
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||||
bool forceComplete) const;
|
bool forceComplete) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class Archive>
|
||||||
|
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
// DiscreteConditional
|
// DiscreteConditional
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,11 @@
|
||||||
// #define DT_DEBUG_MEMORY
|
// #define DT_DEBUG_MEMORY
|
||||||
// #define GTSAM_DT_NO_PRUNING
|
// #define GTSAM_DT_NO_PRUNING
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/discrete/Signature.h>
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,14 @@
|
||||||
* @date Feb 14, 2011
|
* @date Feb 14, 2011
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
|
@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) {
|
||||||
DiscreteConditional conditional(A | B = "2/2 3/1");
|
DiscreteConditional conditional(A | B = "2/2 3/1");
|
||||||
DiscreteConditional prior(B % "1/2");
|
DiscreteConditional prior(B % "1/2");
|
||||||
DiscreteConditional pAB = prior * conditional;
|
DiscreteConditional pAB = prior * conditional;
|
||||||
GTSAM_PRINT(pAB);
|
|
||||||
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
|
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
|
||||||
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,105 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* testSerializtionDiscrete.cpp
|
||||||
|
*
|
||||||
|
* @date January 2023
|
||||||
|
* @author Varun Agrawal
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
using Tree = gtsam::DecisionTree<string, int>;
|
||||||
|
|
||||||
|
BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
|
||||||
|
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
|
||||||
|
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
|
||||||
|
|
||||||
|
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
|
||||||
|
|
||||||
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf")
|
||||||
|
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test DecisionTree serialization.
|
||||||
|
TEST(DiscreteSerialization, DecisionTree) {
|
||||||
|
Tree tree({{"A", 2}}, std::vector<int>{1, 2});
|
||||||
|
|
||||||
|
using namespace serializationTestHelpers;
|
||||||
|
|
||||||
|
// Object roundtrip
|
||||||
|
Tree outputObj = create<Tree>();
|
||||||
|
roundtrip<Tree>(tree, outputObj);
|
||||||
|
EXPECT(tree.equals(outputObj));
|
||||||
|
|
||||||
|
// XML roundtrip
|
||||||
|
Tree outputXml = create<Tree>();
|
||||||
|
roundtripXML<Tree>(tree, outputXml);
|
||||||
|
EXPECT(tree.equals(outputXml));
|
||||||
|
|
||||||
|
// Binary roundtrip
|
||||||
|
Tree outputBinary = create<Tree>();
|
||||||
|
roundtripBinary<Tree>(tree, outputBinary);
|
||||||
|
EXPECT(tree.equals(outputBinary));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
|
||||||
|
TEST(DiscreteSerialization, DecisionTreeFactor) {
|
||||||
|
using namespace serializationTestHelpers;
|
||||||
|
|
||||||
|
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||||
|
|
||||||
|
DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
|
||||||
|
EXPECT(equalsObj<DecisionTreeFactor::ADT>(tree));
|
||||||
|
EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
|
||||||
|
EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));
|
||||||
|
|
||||||
|
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
||||||
|
EXPECT(equalsObj<DecisionTreeFactor>(f));
|
||||||
|
EXPECT(equalsXML<DecisionTreeFactor>(f));
|
||||||
|
EXPECT(equalsBinary<DecisionTreeFactor>(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check serialization for DiscreteConditional & DiscreteDistribution
|
||||||
|
TEST(DiscreteSerialization, DiscreteConditional) {
|
||||||
|
using namespace serializationTestHelpers;
|
||||||
|
|
||||||
|
DiscreteKey A(Symbol('x', 1), 3);
|
||||||
|
DiscreteConditional conditional(A % "1/2/2");
|
||||||
|
|
||||||
|
EXPECT(equalsObj<DiscreteConditional>(conditional));
|
||||||
|
EXPECT(equalsXML<DiscreteConditional>(conditional));
|
||||||
|
EXPECT(equalsBinary<DiscreteConditional>(conditional));
|
||||||
|
|
||||||
|
DiscreteDistribution P(A % "3/2/1");
|
||||||
|
EXPECT(equalsObj<DiscreteDistribution>(P));
|
||||||
|
EXPECT(equalsXML<DiscreteDistribution>(P));
|
||||||
|
EXPECT(equalsBinary<DiscreteDistribution>(P));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
@ -197,6 +197,16 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
*/
|
*/
|
||||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class Archive>
|
||||||
|
void serialize(Archive &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(conditionals_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Return the DiscreteKey vector as a set.
|
/// Return the DiscreteKey vector as a set.
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
bool operator==(const FactorAndConstant &other) const {
|
bool operator==(const FactorAndConstant &other) const {
|
||||||
return factor == other.factor && constant == other.constant;
|
return factor == other.factor && constant == other.constant;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(factor);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(constant);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||||
|
|
@ -179,6 +188,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
return sum;
|
return sum;
|
||||||
}
|
}
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(factors_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,6 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
||||||
auto other = e->asDiscrete();
|
auto other = e->asDiscrete();
|
||||||
return other != nullptr && dc->equals(*other, tol);
|
return other != nullptr && dc->equals(*other, tol);
|
||||||
}
|
}
|
||||||
return inner_->equals(*(e->inner_), tol);
|
|
||||||
|
|
||||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||||
: !(e->inner_);
|
: !(e->inner_);
|
||||||
|
|
|
||||||
|
|
@ -188,6 +188,19 @@ class GTSAM_EXPORT HybridConditional
|
||||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||||
ar& BOOST_SERIALIZATION_NVP(inner_);
|
ar& BOOST_SERIALIZATION_NVP(inner_);
|
||||||
|
|
||||||
|
// register the various casts based on the type of inner_
|
||||||
|
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
|
||||||
|
if (isDiscrete()) {
|
||||||
|
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
|
||||||
|
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
|
||||||
|
} else if (isContinuous()) {
|
||||||
|
boost::serialization::void_cast_register<GaussianConditional, Factor>(
|
||||||
|
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
|
||||||
|
} else {
|
||||||
|
boost::serialization::void_cast_register<GaussianMixture, Factor>(
|
||||||
|
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // HybridConditional
|
}; // HybridConditional
|
||||||
|
|
|
||||||
|
|
@ -40,8 +40,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
|
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
const This *e = dynamic_cast<const This *>(&lf);
|
const This *e = dynamic_cast<const This *>(&lf);
|
||||||
// TODO(Varun) How to compare inner_ when they are abstract types?
|
if (e == nullptr) return false;
|
||||||
return e != nullptr && Base::equals(*e, tol);
|
if (!Base::equals(*e, tol)) return false;
|
||||||
|
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||||
|
: !(e->inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// Default constructor - for serialization.
|
||||||
|
HybridDiscreteFactor() = default;
|
||||||
|
|
||||||
// Implicit conversion from a shared ptr of DF
|
// Implicit conversion from a shared ptr of DF
|
||||||
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
|
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
|
||||||
|
|
||||||
|
|
@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
||||||
/// Return the error of the underlying Discrete Factor.
|
/// Return the error of the underlying Discrete Factor.
|
||||||
double error(const HybridValues &values) const override;
|
double error(const HybridValues &values) const override;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(inner_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
|
|
@ -44,15 +44,21 @@ HybridGaussianFactor::HybridGaussianFactor(HessianFactor &&hf)
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
|
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
|
||||||
const This *e = dynamic_cast<const This *>(&other);
|
const This *e = dynamic_cast<const This *>(&other);
|
||||||
// TODO(Varun) How to compare inner_ when they are abstract types?
|
if (e == nullptr) return false;
|
||||||
return e != nullptr && Base::equals(*e, tol);
|
if (!Base::equals(*e, tol)) return false;
|
||||||
|
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||||
|
: !(e->inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridGaussianFactor::print(const std::string &s,
|
void HybridGaussianFactor::print(const std::string &s,
|
||||||
const KeyFormatter &formatter) const {
|
const KeyFormatter &formatter) const {
|
||||||
HybridFactor::print(s, formatter);
|
HybridFactor::print(s, formatter);
|
||||||
|
if (inner_) {
|
||||||
inner_->print("\n", formatter);
|
inner_->print("\n", formatter);
|
||||||
|
} else {
|
||||||
|
std::cout << "\nGaussian: nullptr" << std::endl;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
using This = HybridGaussianFactor;
|
using This = HybridGaussianFactor;
|
||||||
using shared_ptr = boost::shared_ptr<This>;
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
|
||||||
|
/// @name Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Default constructor - for serialization.
|
||||||
HybridGaussianFactor() = default;
|
HybridGaussianFactor() = default;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -79,7 +83,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
*/
|
*/
|
||||||
explicit HybridGaussianFactor(HessianFactor &&hf);
|
explicit HybridGaussianFactor(HessianFactor &&hf);
|
||||||
|
|
||||||
public:
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
@ -101,6 +105,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
/// Return the error of the underlying Gaussian factor.
|
/// Return the error of the underlying Gaussian factor.
|
||||||
double error(const HybridValues &values) const override;
|
double error(const HybridValues &values) const override;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(inner_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,6 @@
|
||||||
* @date December 2021
|
* @date December 2021
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/serializationTestHelpers.h>
|
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
|
|
@ -31,7 +30,6 @@
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
using namespace gtsam::serializationTestHelpers;
|
|
||||||
|
|
||||||
using noiseModel::Isotropic;
|
using noiseModel::Isotropic;
|
||||||
using symbol_shorthand::M;
|
using symbol_shorthand::M;
|
||||||
|
|
@ -330,20 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||||
discrete_conditional_tree->apply(checker);
|
discrete_conditional_tree->apply(checker);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
|
||||||
// Test HybridBayesNet serialization.
|
|
||||||
TEST(HybridBayesNet, Serialization) {
|
|
||||||
Switching s(4);
|
|
||||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
|
||||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
|
||||||
|
|
||||||
// TODO(Varun) Serialization of inner factor doesn't work. Requires
|
|
||||||
// serialization support for all hybrid factors.
|
|
||||||
// EXPECT(equalsObj<HybridBayesNet>(hbn));
|
|
||||||
// EXPECT(equalsXML<HybridBayesNet>(hbn));
|
|
||||||
// EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test HybridBayesNet sampling.
|
// Test HybridBayesNet sampling.
|
||||||
TEST(HybridBayesNet, Sampling) {
|
TEST(HybridBayesNet, Sampling) {
|
||||||
|
|
|
||||||
|
|
@ -220,22 +220,6 @@ TEST(HybridBayesTree, Choose) {
|
||||||
EXPECT(assert_equal(expected_gbt, gbt));
|
EXPECT(assert_equal(expected_gbt, gbt));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
|
||||||
// Test HybridBayesTree serialization.
|
|
||||||
TEST(HybridBayesTree, Serialization) {
|
|
||||||
Switching s(4);
|
|
||||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
|
||||||
HybridBayesTree hbt =
|
|
||||||
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
|
||||||
|
|
||||||
using namespace gtsam::serializationTestHelpers;
|
|
||||||
// TODO(Varun) Serialization of inner factor doesn't work. Requires
|
|
||||||
// serialization support for all hybrid factors.
|
|
||||||
// EXPECT(equalsObj<HybridBayesTree>(hbt));
|
|
||||||
// EXPECT(equalsXML<HybridBayesTree>(hbt));
|
|
||||||
// EXPECT(equalsBinary<HybridBayesTree>(hbt));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,179 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file testSerializationHybrid.cpp
|
||||||
|
* @brief Unit tests for hybrid serialization
|
||||||
|
* @author Varun Agrawal
|
||||||
|
* @date January 2023
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
|
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||||
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
|
||||||
|
#include "Switching.h"
|
||||||
|
|
||||||
|
// Include for test suite
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
using symbol_shorthand::M;
|
||||||
|
using symbol_shorthand::X;
|
||||||
|
using symbol_shorthand::Z;
|
||||||
|
|
||||||
|
using namespace serializationTestHelpers;
|
||||||
|
|
||||||
|
BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
|
||||||
|
|
||||||
|
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
|
||||||
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
|
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
|
||||||
|
|
||||||
|
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors,
|
||||||
|
"gtsam_GaussianMixtureFactor_Factors");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Leaf,
|
||||||
|
"gtsam_GaussianMixtureFactor_Factors_Leaf");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Choice,
|
||||||
|
"gtsam_GaussianMixtureFactor_Factors_Choice");
|
||||||
|
|
||||||
|
BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals,
|
||||||
|
"gtsam_GaussianMixture_Conditionals");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf,
|
||||||
|
"gtsam_GaussianMixture_Conditionals_Leaf");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice,
|
||||||
|
"gtsam_GaussianMixture_Conditionals_Choice");
|
||||||
|
// Needed since GaussianConditional::FromMeanAndStddev uses it
|
||||||
|
BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic");
|
||||||
|
|
||||||
|
BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet");
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test HybridGaussianFactor serialization.
|
||||||
|
TEST(HybridSerialization, HybridGaussianFactor) {
|
||||||
|
const HybridGaussianFactor factor(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
EXPECT(equalsObj<HybridGaussianFactor>(factor));
|
||||||
|
EXPECT(equalsXML<HybridGaussianFactor>(factor));
|
||||||
|
EXPECT(equalsBinary<HybridGaussianFactor>(factor));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test HybridDiscreteFactor serialization.
|
||||||
|
TEST(HybridSerialization, HybridDiscreteFactor) {
|
||||||
|
DiscreteKeys discreteKeys{{M(0), 2}};
|
||||||
|
const HybridDiscreteFactor factor(
|
||||||
|
DecisionTreeFactor(discreteKeys, std::vector<double>{0.4, 0.6}));
|
||||||
|
|
||||||
|
EXPECT(equalsObj<HybridDiscreteFactor>(factor));
|
||||||
|
EXPECT(equalsXML<HybridDiscreteFactor>(factor));
|
||||||
|
EXPECT(equalsBinary<HybridDiscreteFactor>(factor));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test GaussianMixtureFactor serialization.
|
||||||
|
TEST(HybridSerialization, GaussianMixtureFactor) {
|
||||||
|
KeyVector continuousKeys{X(0)};
|
||||||
|
DiscreteKeys discreteKeys{{M(0), 2}};
|
||||||
|
|
||||||
|
auto A = Matrix::Zero(2, 1);
|
||||||
|
auto b0 = Matrix::Zero(2, 1);
|
||||||
|
auto b1 = Matrix::Ones(2, 1);
|
||||||
|
auto f0 = boost::make_shared<JacobianFactor>(X(0), A, b0);
|
||||||
|
auto f1 = boost::make_shared<JacobianFactor>(X(0), A, b1);
|
||||||
|
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
|
||||||
|
|
||||||
|
const GaussianMixtureFactor factor(continuousKeys, discreteKeys, factors);
|
||||||
|
|
||||||
|
EXPECT(equalsObj<GaussianMixtureFactor>(factor));
|
||||||
|
EXPECT(equalsXML<GaussianMixtureFactor>(factor));
|
||||||
|
EXPECT(equalsBinary<GaussianMixtureFactor>(factor));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test HybridConditional serialization.
|
||||||
|
TEST(HybridSerialization, HybridConditional) {
|
||||||
|
const DiscreteKey mode(M(0), 2);
|
||||||
|
Matrix1 I = Matrix1::Identity();
|
||||||
|
const auto conditional = boost::make_shared<GaussianConditional>(
|
||||||
|
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
|
||||||
|
const HybridConditional hc(conditional);
|
||||||
|
|
||||||
|
EXPECT(equalsObj<HybridConditional>(hc));
|
||||||
|
EXPECT(equalsXML<HybridConditional>(hc));
|
||||||
|
EXPECT(equalsBinary<HybridConditional>(hc));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test GaussianMixture serialization.
|
||||||
|
TEST(HybridSerialization, GaussianMixture) {
|
||||||
|
const DiscreteKey mode(M(0), 2);
|
||||||
|
Matrix1 I = Matrix1::Identity();
|
||||||
|
const auto conditional0 = boost::make_shared<GaussianConditional>(
|
||||||
|
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
|
||||||
|
const auto conditional1 = boost::make_shared<GaussianConditional>(
|
||||||
|
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3));
|
||||||
|
const GaussianMixture gm({Z(0)}, {X(0)}, {mode},
|
||||||
|
{conditional0, conditional1});
|
||||||
|
|
||||||
|
EXPECT(equalsObj<GaussianMixture>(gm));
|
||||||
|
EXPECT(equalsXML<GaussianMixture>(gm));
|
||||||
|
EXPECT(equalsBinary<GaussianMixture>(gm));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test HybridBayesNet serialization.
|
||||||
|
TEST(HybridSerialization, HybridBayesNet) {
|
||||||
|
Switching s(2);
|
||||||
|
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||||
|
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||||
|
|
||||||
|
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||||
|
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||||
|
EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test HybridBayesTree serialization.
|
||||||
|
TEST(HybridSerialization, HybridBayesTree) {
|
||||||
|
Switching s(2);
|
||||||
|
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||||
|
HybridBayesTree hbt =
|
||||||
|
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
||||||
|
|
||||||
|
EXPECT(equalsObj<HybridBayesTree>(hbt));
|
||||||
|
EXPECT(equalsXML<HybridBayesTree>(hbt));
|
||||||
|
EXPECT(equalsBinary<HybridBayesTree>(hbt));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
Loading…
Reference in New Issue