Merge pull request #1961 from borglab/serialize-tablefactor
commit
e9e52ad21f
|
@ -24,6 +24,7 @@
|
||||||
|
|
||||||
#include <gtsam/base/Matrix.h>
|
#include <gtsam/base/Matrix.h>
|
||||||
|
|
||||||
|
#include <Eigen/Sparse>
|
||||||
#include <boost/serialization/array.hpp>
|
#include <boost/serialization/array.hpp>
|
||||||
#include <boost/serialization/nvp.hpp>
|
#include <boost/serialization/nvp.hpp>
|
||||||
#include <boost/serialization/split_free.hpp>
|
#include <boost/serialization/split_free.hpp>
|
||||||
|
@ -87,6 +88,45 @@ void serialize(Archive& ar, gtsam::Matrix& m, const unsigned int version) {
|
||||||
split_free(ar, m, version);
|
split_free(ar, m, version);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
/// Customized functions for serializing Eigen::SparseVector
|
||||||
|
template <class Archive, typename _Scalar, int _Options, typename _Index>
|
||||||
|
void save(Archive& ar, const Eigen::SparseVector<_Scalar, _Options, _Index>& m,
|
||||||
|
const unsigned int /*version*/) {
|
||||||
|
_Index size = m.size();
|
||||||
|
|
||||||
|
std::vector<std::pair<Eigen::Index, _Scalar>> data;
|
||||||
|
for (typename Eigen::SparseVector<_Scalar, _Options, _Index>::InnerIterator
|
||||||
|
it(m);
|
||||||
|
it; ++it)
|
||||||
|
data.push_back({it.index(), it.value()});
|
||||||
|
|
||||||
|
ar << BOOST_SERIALIZATION_NVP(size);
|
||||||
|
ar << BOOST_SERIALIZATION_NVP(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Archive, typename _Scalar, int _Options, typename _Index>
|
||||||
|
void load(Archive& ar, Eigen::SparseVector<_Scalar, _Options, _Index>& m,
|
||||||
|
const unsigned int /*version*/) {
|
||||||
|
_Index size;
|
||||||
|
ar >> BOOST_SERIALIZATION_NVP(size);
|
||||||
|
m.resize(size);
|
||||||
|
|
||||||
|
std::vector<std::pair<Eigen::Index, _Scalar>> data;
|
||||||
|
ar >> BOOST_SERIALIZATION_NVP(data);
|
||||||
|
|
||||||
|
for (auto&& d : data) {
|
||||||
|
m.coeffRef(d.first) = d.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class Archive, typename _Scalar, int _Options, typename _Index>
|
||||||
|
void serialize(Archive& ar, Eigen::SparseVector<_Scalar, _Options, _Index>& m,
|
||||||
|
const unsigned int version) {
|
||||||
|
split_free(ar, m, version);
|
||||||
|
}
|
||||||
|
/******************************************************************************/
|
||||||
|
|
||||||
} // namespace serialization
|
} // namespace serialization
|
||||||
} // namespace boost
|
} // namespace boost
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -31,6 +31,12 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
|
#include <gtsam/base/MatrixSerialization.h>
|
||||||
|
|
||||||
|
#include <boost/serialization/nvp.hpp>
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class DiscreteConditional;
|
class DiscreteConditional;
|
||||||
|
@ -342,6 +348,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
double error(const HybridValues& values) const override;
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(sparse_table_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(denominators_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(sorted_dkeys_);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <gtsam/base/serializationTestHelpers.h>
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
@ -32,6 +33,7 @@ BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
|
||||||
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
|
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
|
||||||
|
|
||||||
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
|
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
|
||||||
|
BOOST_CLASS_EXPORT_GUID(TableFactor, "gtsam_TableFactor");
|
||||||
|
|
||||||
using ADT = AlgebraicDecisionTree<Key>;
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
|
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
|
||||||
|
@ -79,6 +81,19 @@ TEST(DiscreteSerialization, DecisionTreeFactor) {
|
||||||
EXPECT(equalsBinary<DecisionTreeFactor>(f));
|
EXPECT(equalsBinary<DecisionTreeFactor>(f));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check serialization for TableFactor
|
||||||
|
TEST(DiscreteSerialization, TableFactor) {
|
||||||
|
using namespace serializationTestHelpers;
|
||||||
|
|
||||||
|
DiscreteKey A(Symbol('x', 1), 3);
|
||||||
|
TableFactor tf(A, "1 2 2");
|
||||||
|
|
||||||
|
EXPECT(equalsObj<TableFactor>(tf));
|
||||||
|
EXPECT(equalsXML<TableFactor>(tf));
|
||||||
|
EXPECT(equalsBinary<TableFactor>(tf));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check serialization for DiscreteConditional & DiscreteDistribution
|
// Check serialization for DiscreteConditional & DiscreteDistribution
|
||||||
TEST(DiscreteSerialization, DiscreteConditional) {
|
TEST(DiscreteSerialization, DiscreteConditional) {
|
||||||
|
|
Loading…
Reference in New Issue