commit
30c913e0f1
|
|
@ -48,4 +48,25 @@ namespace gtsam {
|
||||||
return keys & key2;
|
return keys & key2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void DiscreteKeys::print(const std::string& s,
|
||||||
|
const KeyFormatter& keyFormatter) const {
|
||||||
|
for (auto&& dkey : *this) {
|
||||||
|
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const {
|
||||||
|
if (this->size() != other.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
|
if (this->at(i).first != other.at(i).first ||
|
||||||
|
this->at(i).second != other.at(i).second) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,7 @@
|
||||||
#include <gtsam/global_includes.h>
|
#include <gtsam/global_includes.h>
|
||||||
#include <gtsam/inference/Key.h>
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
||||||
|
#include <boost/serialization/vector.hpp>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
@ -72,15 +73,27 @@ namespace gtsam {
|
||||||
|
|
||||||
/// Print the keys and cardinalities.
|
/// Print the keys and cardinalities.
|
||||||
void print(const std::string& s = "",
|
void print(const std::string& s = "",
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||||
for (auto&& dkey : *this) {
|
|
||||||
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
|
/// Check equality to another DiscreteKeys object.
|
||||||
<< std::endl;
|
bool equals(const DiscreteKeys& other, double tol = 0) const;
|
||||||
}
|
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& boost::serialization::make_nvp(
|
||||||
|
"DiscreteKeys",
|
||||||
|
boost::serialization::base_object<std::vector<DiscreteKey>>(*this));
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // DiscreteKeys
|
}; // DiscreteKeys
|
||||||
|
|
||||||
/// Create a list from two keys
|
/// Create a list from two keys
|
||||||
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
|
||||||
}
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<DiscreteKeys> : public Testable<DiscreteKeys> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -16,14 +16,29 @@
|
||||||
* @author Duy-Nguyen Ta
|
* @author Duy-Nguyen Ta
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
|
|
||||||
#include <boost/assign/std/map.hpp>
|
#include <boost/assign/std/map.hpp>
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
using namespace gtsam::serializationTestHelpers;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DisreteKeys, Serialization) {
|
||||||
|
DiscreteKeys keys;
|
||||||
|
keys& DiscreteKey(0, 2);
|
||||||
|
keys& DiscreteKey(1, 3);
|
||||||
|
keys& DiscreteKey(2, 4);
|
||||||
|
|
||||||
|
EXPECT(equalsObj<DiscreteKeys>(keys));
|
||||||
|
EXPECT(equalsXML<DiscreteKeys>(keys));
|
||||||
|
EXPECT(equalsBinary<DiscreteKeys>(keys));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
|
|
@ -31,4 +46,3 @@ int main() {
|
||||||
return TestRegistry::runAllTests(tr);
|
return TestRegistry::runAllTests(tr);
|
||||||
}
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/global_includes.h>
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
|
@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
|
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
|
||||||
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** Construct empty bayes net */
|
/** Construct empty bayes net */
|
||||||
HybridBayesNet() = default;
|
HybridBayesNet() = default;
|
||||||
|
|
||||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
/// @}
|
||||||
HybridBayesNet prune(
|
/// @name Testable
|
||||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
/// @{
|
||||||
|
|
||||||
|
/** Check equality */
|
||||||
|
bool equals(const This &bn, double tol = 1e-9) const {
|
||||||
|
return Base::equals(bn, tol);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// print graph
|
||||||
|
void print(
|
||||||
|
const std::string &s = "",
|
||||||
|
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
|
||||||
|
Base::print(s, formatter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// Add HybridConditional to Bayes Net
|
/// Add HybridConditional to Bayes Net
|
||||||
using Base::add;
|
using Base::add;
|
||||||
|
|
@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
*/
|
*/
|
||||||
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
/// Solve the HybridBayesNet by back-substitution.
|
/**
|
||||||
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
|
* @brief Solve the HybridBayesNet by first computing the MPE of all the
|
||||||
/// put this method there?
|
* discrete variables and then optimizing the continuous variables based on
|
||||||
|
* the MPE assignment.
|
||||||
|
*
|
||||||
|
* @return HybridValues
|
||||||
|
*/
|
||||||
HybridValues optimize() const;
|
HybridValues optimize() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @return Values
|
* @return Values
|
||||||
*/
|
*/
|
||||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
|
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
||||||
|
HybridBayesNet prune(
|
||||||
|
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// traits
|
||||||
|
template <>
|
||||||
|
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -89,8 +89,20 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
||||||
VectorValues optimize(const DiscreteValues& assignment) const;
|
VectorValues optimize(const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// traits
|
||||||
|
template <>
|
||||||
|
struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Class for Hybrid Bayes tree orphan subtrees.
|
* @brief Class for Hybrid Bayes tree orphan subtrees.
|
||||||
*
|
*
|
||||||
|
|
|
||||||
|
|
@ -178,6 +178,15 @@ class GTSAM_EXPORT HybridConditional
|
||||||
/// Get the type-erased pointer to the inner type
|
/// Get the type-erased pointer to the inner type
|
||||||
boost::shared_ptr<Factor> inner() { return inner_; }
|
boost::shared_ptr<Factor> inner() { return inner_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class Archive>
|
||||||
|
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||||
|
}
|
||||||
|
|
||||||
}; // HybridConditional
|
}; // HybridConditional
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
||||||
bool isContinuous_ = false;
|
bool isContinuous_ = false;
|
||||||
bool isHybrid_ = false;
|
bool isHybrid_ = false;
|
||||||
|
|
||||||
|
// TODO(Varun) remove
|
||||||
size_t nrContinuous_ = 0;
|
size_t nrContinuous_ = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
@ -129,6 +130,19 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
||||||
const KeyVector &continuousKeys() const { return continuousKeys_; }
|
const KeyVector &continuousKeys() const { return continuousKeys_; }
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
// HybridFactor
|
// HybridFactor
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@
|
||||||
* @date December 2021
|
* @date December 2021
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
|
|
||||||
|
|
@ -28,6 +29,8 @@
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
using namespace gtsam::serializationTestHelpers;
|
||||||
|
|
||||||
using noiseModel::Isotropic;
|
using noiseModel::Isotropic;
|
||||||
using symbol_shorthand::M;
|
using symbol_shorthand::M;
|
||||||
using symbol_shorthand::X;
|
using symbol_shorthand::X;
|
||||||
|
|
@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) {
|
||||||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test HybridBayesNet serialization.
|
||||||
|
TEST(HybridBayesNet, Serialization) {
|
||||||
|
Switching s(4);
|
||||||
|
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||||
|
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||||
|
|
||||||
|
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||||
|
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||||
|
EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@
|
||||||
* @date August 2022
|
* @date August 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||||
|
|
@ -143,6 +144,20 @@ TEST(HybridBayesTree, Optimize) {
|
||||||
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test HybridBayesTree serialization.
|
||||||
|
TEST(HybridBayesTree, Serialization) {
|
||||||
|
Switching s(4);
|
||||||
|
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||||
|
HybridBayesTree hbt =
|
||||||
|
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
||||||
|
|
||||||
|
using namespace gtsam::serializationTestHelpers;
|
||||||
|
EXPECT(equalsObj<HybridBayesTree>(hbt));
|
||||||
|
EXPECT(equalsXML<HybridBayesTree>(hbt));
|
||||||
|
EXPECT(equalsBinary<HybridBayesTree>(hbt));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) {
|
||||||
EXPECT(equalsBinary(graph));
|
EXPECT(equalsBinary(graph));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
TEST(Serialization, gaussian_bayes_net) {
|
||||||
|
// Create an arbitrary Bayes Net
|
||||||
|
GaussianBayesNet gbn;
|
||||||
|
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||||
|
0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3,
|
||||||
|
(Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4,
|
||||||
|
(Matrix2() << 11.0, 12.0, 13.0, 14.0).finished()));
|
||||||
|
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||||
|
1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(),
|
||||||
|
2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4,
|
||||||
|
(Matrix2() << 25.0, 26.0, 27.0, 28.0).finished()));
|
||||||
|
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||||
|
2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(),
|
||||||
|
3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished()));
|
||||||
|
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||||
|
3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(),
|
||||||
|
4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished()));
|
||||||
|
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
|
||||||
|
4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished()));
|
||||||
|
|
||||||
|
std::string serialized = serialize(gbn);
|
||||||
|
GaussianBayesNet actual;
|
||||||
|
deserialize(serialized, actual);
|
||||||
|
EXPECT(assert_equal(gbn, actual));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST (Serialization, gaussian_bayes_tree) {
|
TEST (Serialization, gaussian_bayes_tree) {
|
||||||
const Key x1=1, x2=2, x3=3, x4=4;
|
const Key x1=1, x2=2, x3=3, x4=4;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue