GaussianMixture serialization

release/4.3a0
Varun Agrawal 2023-01-03 17:37:09 -05:00
parent 3f2bff8e1d
commit 74142e4be1
3 changed files with 67 additions and 9 deletions

View File

@ -197,6 +197,16 @@ class GTSAM_EXPORT GaussianMixture
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar &BOOST_SERIALIZATION_NVP(conditionals_);
}
};
/// Return the DiscreteKey vector as a set.

View File

@ -70,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_NVP(factor);
ar &BOOST_SERIALIZATION_NVP(constant);
}
};
/// typedef for Decision Tree of Gaussian factors and log-constant.
@ -179,6 +188,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
return sum;
}
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(factors_);
}
};
// traits

View File

@ -17,10 +17,12 @@
*/
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
// Include for test suite
#include <CppUnitLite/TestHarness.h>
@ -29,26 +31,37 @@ using namespace std;
using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;
using namespace serializationTestHelpers;
BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_DecisionTree_Leaf")
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_DecisionTree_Choice")
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_DecisionTree_Leaf");
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_DecisionTree_Choice");
BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixtureFactor, "gtsam_GaussianMixtureFactor")
BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixtureFactor::Factors, "gtsam_GaussianMixtureFactor_Factors")
BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixtureFactor::Factors::Leaf, "gtsam_GaussianMixtureFactor_Factors_Leaf")
BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixtureFactor::Factors::Choice, "gtsam_GaussianMixtureFactor_Factors_Choice")
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors,
"gtsam_GaussianMixtureFactor_Factors");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Leaf,
"gtsam_GaussianMixtureFactor_Factors_Leaf");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Choice,
"gtsam_GaussianMixtureFactor_Factors_Choice");
// BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixture, "gtsam_GaussianMixture")
// BOOST_CLASS_EXPORT_GUID(gtsam::GaussianMixture::Conditionals,
// "gtsam_GaussianMixture_Conditionals")
BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals,
"gtsam_GaussianMixture_Conditionals");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf,
"gtsam_GaussianMixture_Conditionals_Leaf");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice,
"gtsam_GaussianMixture_Conditionals_Choice");
// Needed since GaussianConditional::FromMeanAndStddev uses it
BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic");
/* ****************************************************************************/
// Test HybridGaussianFactor serialization.
@ -92,6 +105,23 @@ TEST(HybridSerialization, GaussianMixtureFactor) {
EXPECT(equalsBinary<GaussianMixtureFactor>(factor));
}
/* ****************************************************************************/
// Test GaussianMixture serialization.
TEST(HybridSerialization, GaussianMixture) {
const DiscreteKey mode(M(0), 2);
Matrix1 I = Matrix1::Identity();
const auto conditional0 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
const auto conditional1 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3));
const GaussianMixture gm({Z(0)}, {X(0)}, {mode},
{conditional0, conditional1});
EXPECT(equalsObj<GaussianMixture>(gm));
EXPECT(equalsXML<GaussianMixture>(gm));
EXPECT(equalsBinary<GaussianMixture>(gm));
}
/* ************************************************************************* */
int main() {
TestResult tr;