Merge pull request #1286 from borglab/hybrid/serialization

Hybrid Serialization
release/4.3a0
Varun Agrawal 2022-09-01 20:46:45 -04:00 committed by GitHub
commit 30c913e0f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 197 additions and 15 deletions

View File

@ -48,4 +48,25 @@ namespace gtsam {
return keys & key2;
}
void DiscreteKeys::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
for (auto&& dkey : *this) {
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
<< std::endl;
}
}
bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const {
if (this->size() != other.size()) {
return false;
}
for (size_t i = 0; i < this->size(); i++) {
if (this->at(i).first != other.at(i).first ||
this->at(i).second != other.at(i).second) {
return false;
}
}
return true;
}
}

View File

@ -21,6 +21,7 @@
#include <gtsam/global_includes.h>
#include <gtsam/inference/Key.h>
#include <boost/serialization/vector.hpp>
#include <map>
#include <string>
#include <vector>
@ -72,15 +73,27 @@ namespace gtsam {
/// Print the keys and cardinalities.
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
for (auto&& dkey : *this) {
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
<< std::endl;
}
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// Check equality to another DiscreteKeys object.
bool equals(const DiscreteKeys& other, double tol = 0) const;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& boost::serialization::make_nvp(
"DiscreteKeys",
boost::serialization::base_object<std::vector<DiscreteKey>>(*this));
}
}; // DiscreteKeys
/// Create a list from two keys
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
}
// traits
template <>
struct traits<DiscreteKeys> : public Testable<DiscreteKeys> {};
} // namespace gtsam

View File

@ -16,14 +16,29 @@
* @author Duy-Nguyen Ta
*/
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <boost/assign/std/map.hpp>
using namespace boost::assign;
using namespace std;
using namespace gtsam;
using namespace gtsam::serializationTestHelpers;
/* ************************************************************************* */
TEST(DisreteKeys, Serialization) {
DiscreteKeys keys;
keys& DiscreteKey(0, 2);
keys& DiscreteKey(1, 3);
keys& DiscreteKey(2, 4);
EXPECT(equalsObj<DiscreteKeys>(keys));
EXPECT(equalsXML<DiscreteKeys>(keys));
EXPECT(equalsBinary<DiscreteKeys>(keys));
}
/* ************************************************************************* */
int main() {
@ -31,4 +46,3 @@ int main() {
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -18,6 +18,7 @@
#pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/global_includes.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
using sharedConditional = boost::shared_ptr<ConditionalType>;
/// @name Standard Constructors
/// @{
/** Construct empty bayes net */
HybridBayesNet() = default;
/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
/// @}
/// @name Testable
/// @{
/** Check equality */
bool equals(const This &bn, double tol = 1e-9) const {
return Base::equals(bn, tol);
}
/// print graph
void print(
const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
/// @}
/// @name Standard Interface
/// @{
/// Add HybridConditional to Bayes Net
using Base::add;
@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;
/// Solve the HybridBayesNet by back-substitution.
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
/**
* @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on
* the MPE assignment.
*
* @return HybridValues
*/
HybridValues optimize() const;
/**
@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @return Values
*/
VectorValues optimize(const DiscreteValues &assignment) const;
/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
/// traits
template <>
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
} // namespace gtsam

View File

@ -89,8 +89,20 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
VectorValues optimize(const DiscreteValues& assignment) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
/// traits
template <>
struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};
/**
* @brief Class for Hybrid Bayes tree orphan subtrees.
*

View File

@ -178,6 +178,15 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; }
private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}
}; // HybridConditional
// traits

View File

@ -47,6 +47,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
bool isContinuous_ = false;
bool isHybrid_ = false;
// TODO(Varun) remove
size_t nrContinuous_ = 0;
protected:
@ -129,6 +130,19 @@ class GTSAM_EXPORT HybridFactor : public Factor {
const KeyVector &continuousKeys() const { return continuousKeys_; }
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
}
};
// HybridFactor

View File

@ -18,6 +18,7 @@
* @date December 2021
*/
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
@ -28,6 +29,8 @@
using namespace std;
using namespace gtsam;
using namespace gtsam::serializationTestHelpers;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) {
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}
/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -16,6 +16,7 @@
* @date August 2022
*/
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
@ -143,6 +144,20 @@ TEST(HybridBayesTree, Optimize) {
EXPECT(assert_equal(expectedValues, delta.continuous()));
}
/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
using namespace gtsam::serializationTestHelpers;
EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
EXPECT(equalsBinary<HybridBayesTree>(hbt));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) {
EXPECT(equalsBinary(graph));
}
/* ****************************************************************************/
TEST(Serialization, gaussian_bayes_net) {
// Create an arbitrary Bayes Net
GaussianBayesNet gbn;
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3,
(Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4,
(Matrix2() << 11.0, 12.0, 13.0, 14.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(),
2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4,
(Matrix2() << 25.0, 26.0, 27.0, 28.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(),
3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(),
4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished()));
std::string serialized = serialize(gbn);
GaussianBayesNet actual;
deserialize(serialized, actual);
EXPECT(assert_equal(gbn, actual));
}
/* ************************************************************************* */
TEST (Serialization, gaussian_bayes_tree) {
const Key x1=1, x2=2, x3=3, x4=4;