commit
30c913e0f1
|
|
@ -48,4 +48,25 @@ namespace gtsam {
|
|||
return keys & key2;
|
||||
}
|
||||
|
||||
void DiscreteKeys::print(const std::string& s,
|
||||
const KeyFormatter& keyFormatter) const {
|
||||
for (auto&& dkey : *this) {
|
||||
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const {
|
||||
if (this->size() != other.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
if (this->at(i).first != other.at(i).first ||
|
||||
this->at(i).second != other.at(i).second) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
#include <gtsam/global_includes.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
|
||||
#include <boost/serialization/vector.hpp>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
|
@ -72,15 +73,27 @@ namespace gtsam {
|
|||
|
||||
/// Print the keys and cardinalities.
|
||||
void print(const std::string& s = "",
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||
for (auto&& dkey : *this) {
|
||||
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
|
||||
<< std::endl;
|
||||
}
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// Check equality to another DiscreteKeys object.
|
||||
bool equals(const DiscreteKeys& other, double tol = 0) const;
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& boost::serialization::make_nvp(
|
||||
"DiscreteKeys",
|
||||
boost::serialization::base_object<std::vector<DiscreteKey>>(*this));
|
||||
}
|
||||
|
||||
}; // DiscreteKeys
|
||||
|
||||
/// Create a list from two keys
|
||||
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
||||
}
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<DiscreteKeys> : public Testable<DiscreteKeys> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -16,14 +16,29 @@
|
|||
* @author Duy-Nguyen Ta
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
|
||||
#include <boost/assign/std/map.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DisreteKeys, Serialization) {
|
||||
DiscreteKeys keys;
|
||||
keys& DiscreteKey(0, 2);
|
||||
keys& DiscreteKey(1, 3);
|
||||
keys& DiscreteKey(2, 4);
|
||||
|
||||
EXPECT(equalsObj<DiscreteKeys>(keys));
|
||||
EXPECT(equalsXML<DiscreteKeys>(keys));
|
||||
EXPECT(equalsBinary<DiscreteKeys>(keys));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
|
|
@ -31,4 +46,3 @@ int main() {
|
|||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/global_includes.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
|
|
@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
|
||||
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Construct empty bayes net */
|
||||
HybridBayesNet() = default;
|
||||
|
||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
||||
HybridBayesNet prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This &bn, double tol = 1e-9) const {
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/// print graph
|
||||
void print(
|
||||
const std::string &s = "",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Add HybridConditional to Bayes Net
|
||||
using Base::add;
|
||||
|
|
@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
||||
|
||||
/// Solve the HybridBayesNet by back-substitution.
|
||||
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
|
||||
/// put this method there?
|
||||
/**
|
||||
* @brief Solve the HybridBayesNet by first computing the MPE of all the
|
||||
* discrete variables and then optimizing the continuous variables based on
|
||||
* the MPE assignment.
|
||||
*
|
||||
* @return HybridValues
|
||||
*/
|
||||
HybridValues optimize() const;
|
||||
|
||||
/**
|
||||
|
|
@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @return Values
|
||||
*/
|
||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||
|
||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
||||
HybridBayesNet prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
/// traits
|
||||
template <>
|
||||
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -89,8 +89,20 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
|||
VectorValues optimize(const DiscreteValues& assignment) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
/// traits
|
||||
template <>
|
||||
struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};
|
||||
|
||||
/**
|
||||
* @brief Class for Hybrid Bayes tree orphan subtrees.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -178,6 +178,15 @@ class GTSAM_EXPORT HybridConditional
|
|||
/// Get the type-erased pointer to the inner type
|
||||
boost::shared_ptr<Factor> inner() { return inner_; }
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
}
|
||||
|
||||
}; // HybridConditional
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
|||
bool isContinuous_ = false;
|
||||
bool isHybrid_ = false;
|
||||
|
||||
// TODO(Varun) remove
|
||||
size_t nrContinuous_ = 0;
|
||||
|
||||
protected:
|
||||
|
|
@ -129,6 +130,19 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
|||
const KeyVector &continuousKeys() const { return continuousKeys_; }
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
|
||||
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
|
||||
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
|
||||
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
|
||||
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
|
||||
}
|
||||
};
|
||||
// HybridFactor
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
* @date December 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
|
||||
|
|
@ -28,6 +29,8 @@
|
|||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
|
@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) {
|
|||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridBayesNet, Serialization) {
|
||||
Switching s(4);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||
|
||||
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
* @date August 2022
|
||||
*/
|
||||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||
|
|
@ -143,6 +144,20 @@ TEST(HybridBayesTree, Optimize) {
|
|||
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesTree serialization.
|
||||
TEST(HybridBayesTree, Serialization) {
|
||||
Switching s(4);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesTree hbt =
|
||||
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
||||
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
EXPECT(equalsObj<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsXML<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsBinary<HybridBayesTree>(hbt));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
|
|
@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) {
|
|||
EXPECT(equalsBinary(graph));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
TEST(Serialization, gaussian_bayes_net) {
|
||||
// Create an arbitrary Bayes Net
|
||||
GaussianBayesNet gbn;
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3,
|
||||
(Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4,
|
||||
(Matrix2() << 11.0, 12.0, 13.0, 14.0).finished()));
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(),
|
||||
2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4,
|
||||
(Matrix2() << 25.0, 26.0, 27.0, 28.0).finished()));
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(),
|
||||
3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished()));
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(),
|
||||
4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished()));
|
||||
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||
4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished()));
|
||||
|
||||
std::string serialized = serialize(gbn);
|
||||
GaussianBayesNet actual;
|
||||
deserialize(serialized, actual);
|
||||
EXPECT(assert_equal(gbn, actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST (Serialization, gaussian_bayes_tree) {
|
||||
const Key x1=1, x2=2, x3=3, x4=4;
|
||||
|
|
|
|||
Loading…
Reference in New Issue