From d2dc620b1e3031fc35c5ebe9d5030235b54ebf68 Mon Sep 17 00:00:00 2001 From: Fan Jiang Date: Fri, 25 Mar 2022 17:14:00 -0600 Subject: [PATCH] Add Python bindings --- gtsam/hybrid/GaussianMixture.cpp | 19 ++- gtsam/hybrid/GaussianMixture.h | 5 + gtsam/hybrid/GaussianMixtureFactor.cpp | 45 ++++--- gtsam/hybrid/GaussianMixtureFactor.h | 11 +- gtsam/hybrid/HybridConditional.cpp | 9 +- gtsam/hybrid/HybridDiscreteFactor.h | 4 +- gtsam/hybrid/HybridFactorGraph.cpp | 14 ++ gtsam/hybrid/HybridFactorGraph.h | 7 + gtsam/hybrid/hybrid.i | 134 +++++++++++++++++++ gtsam/hybrid/tests/testHybridConditional.cpp | 29 ++-- gtsam/inference/inference.i | 13 +- python/CMakeLists.txt | 1 + python/gtsam/preamble/hybrid.h | 14 ++ python/gtsam/specializations/hybrid.h | 4 + python/gtsam/tests/test_HybridFactorGraph.py | 49 +++++++ 15 files changed, 311 insertions(+), 47 deletions(-) create mode 100644 gtsam/hybrid/hybrid.i create mode 100644 python/gtsam/preamble/hybrid.h create mode 100644 python/gtsam/specializations/hybrid.h create mode 100644 python/gtsam/tests/test_HybridFactorGraph.py diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 26668da59..bc674966c 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -35,12 +35,23 @@ GaussianMixture::GaussianMixture( BaseConditional(continuousFrontals.size()), conditionals_(conditionals) {} -const GaussianMixture::Conditionals& GaussianMixture::conditionals() { +const GaussianMixture::Conditionals &GaussianMixture::conditionals() { return conditionals_; } +GaussianMixture GaussianMixture::FromConditionalList( + const KeyVector &continuousFrontals, const KeyVector &continuousParents, + const DiscreteKeys &discreteParents, + const std::vector &conditionalsList) { + Conditionals dt(discreteParents, conditionalsList); + + return GaussianMixture(continuousFrontals, continuousParents, discreteParents, + dt); +} + /* *******************************************************************************/ -GaussianMixture::Sum GaussianMixture::addTo(const GaussianMixture::Sum &sum) const { +GaussianMixture::Sum GaussianMixture::addTo( + const GaussianMixture::Sum &sum) const { using Y = GaussianFactorGraph; auto add = [](const Y &graph1, const Y &graph2) { auto result = graph1; @@ -66,7 +77,7 @@ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { } void GaussianMixture::print(const std::string &s, - const KeyFormatter &formatter) const { + const KeyFormatter &formatter) const { std::cout << s << ": "; if (isContinuous_) std::cout << "Cont. "; if (isDiscrete_) std::cout << "Disc. "; @@ -88,4 +99,4 @@ void GaussianMixture::print(const std::string &s, return rd.str(); }); } -} // namespace gtsam \ No newline at end of file +} // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 9879e80bd..4412e741c 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -60,6 +60,11 @@ class GaussianMixture : public HybridFactor, /* *******************************************************************************/ Sum wrappedConditionals() const; + static This FromConditionalList( + const KeyVector &continuousFrontals, const KeyVector &continuousParents, + const DiscreteKeys &discreteParents, + const std::vector &conditionals); + bool equals(const HybridFactor &lf, double tol = 1e-9) const override; void print( diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index 61cfb1d70..3963e675e 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -18,43 +18,52 @@ * @date Mar 12, 2022 */ -#include - -#include -#include -#include #include +#include +#include +#include +#include namespace gtsam { GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, - const DiscreteKeys &discreteKeys, - const Factors &factors) : Base(continuousKeys, discreteKeys), - factors_(factors) {} + const DiscreteKeys &discreteKeys, + const Factors &factors) + : Base(continuousKeys, discreteKeys), factors_(factors) {} bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { return false; } -void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { +GaussianMixtureFactor GaussianMixtureFactor::FromFactorList( + const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, + const std::vector &factorsList) { + Factors dt(discreteKeys, factorsList); + + return GaussianMixtureFactor(continuousKeys, discreteKeys, dt); +} + +void GaussianMixtureFactor::print(const std::string &s, + const KeyFormatter &formatter) const { HybridFactor::print(s, formatter); factors_.print( - "mixture = ", - [&](Key k) { - return formatter(k); - }, [&](const GaussianFactor::shared_ptr &gf) -> std::string { + "mixture = ", [&](Key k) { return formatter(k); }, + [&](const GaussianFactor::shared_ptr &gf) -> std::string { RedirectCout rd; - if (!gf->empty()) gf->print("", formatter); - else return {"nullptr"}; + if (!gf->empty()) + gf->print("", formatter); + else + return {"nullptr"}; return rd.str(); }); } -const GaussianMixtureFactor::Factors& GaussianMixtureFactor::factors() { +const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { return factors_; } /* *******************************************************************************/ -GaussianMixtureFactor::Sum GaussianMixtureFactor::addTo(const GaussianMixtureFactor::Sum &sum) const { +GaussianMixtureFactor::Sum GaussianMixtureFactor::addTo( + const GaussianMixtureFactor::Sum &sum) const { using Y = GaussianFactorGraph; auto add = [](const Y &graph1, const Y &graph2) { auto result = graph1; @@ -74,4 +83,4 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::wrappedFactors() const { }; return {factors_, wrap}; } -} +} // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 4cd5e71de..57a0cca03 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -29,6 +29,8 @@ namespace gtsam { class GaussianFactorGraph; +typedef std::vector GaussianFactorVector; + class GaussianMixtureFactor : public HybridFactor { public: using Base = HybridFactor; @@ -42,11 +44,16 @@ class GaussianMixtureFactor : public HybridFactor { GaussianMixtureFactor() = default; GaussianMixtureFactor(const KeyVector &continuousKeys, - const DiscreteKeys &discreteKeys, const Factors &factors); + const DiscreteKeys &discreteKeys, + const Factors &factors); using Sum = DecisionTree; - const Factors& factors(); + const Factors &factors(); + + static This FromFactorList( + const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, + const std::vector &factors); /* *******************************************************************************/ Sum addTo(const Sum &sum) const; diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index ed4b01608..48bee192c 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -16,9 +16,9 @@ */ #include +#include #include -#include "gtsam/hybrid/HybridFactor.h" -#include "gtsam/inference/Key.h" +#include namespace gtsam { @@ -27,8 +27,9 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents) : HybridConditional( - CollectKeys({continuousFrontals.begin(), continuousFrontals.end()}, - KeyVector {continuousParents.begin(), continuousParents.end()}), + CollectKeys( + {continuousFrontals.begin(), continuousFrontals.end()}, + KeyVector{continuousParents.begin(), continuousParents.end()}), CollectDiscreteKeys( {discreteFrontals.begin(), discreteFrontals.end()}, {discreteParents.begin(), discreteParents.end()}), diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h index c55c5ecf0..809510eac 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -35,10 +35,10 @@ class HybridDiscreteFactor : public HybridFactor { DiscreteFactor::shared_ptr inner; - // Implicit conversion from a shared ptr of GF + // Implicit conversion from a shared ptr of DF HybridDiscreteFactor(DiscreteFactor::shared_ptr other); - // Forwarding constructor from concrete JacobianFactor + // Forwarding constructor from concrete DecisionTreeFactor HybridDiscreteFactor(DecisionTreeFactor &&dtf); public: diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index 9e7eebc57..6c7599a12 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include @@ -380,4 +381,17 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { void HybridFactorGraph::add(JacobianFactor &&factor) { FactorGraph::add(boost::make_shared(std::move(factor))); } + +void HybridFactorGraph::add(DecisionTreeFactor &&factor) { + FactorGraph::add(boost::make_shared(std::move(factor))); +} + +void HybridFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { + FactorGraph::add(boost::make_shared(factor)); +} + +void HybridFactorGraph::add(JacobianFactor::shared_ptr factor) { + FactorGraph::add(boost::make_shared(factor)); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index 21332d80a..bfd6a8690 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -32,6 +32,7 @@ class HybridBayesNet; class HybridEliminationTree; class HybridBayesTree; class HybridJunctionTree; +class DecisionTreeFactor; class JacobianFactor; @@ -98,6 +99,12 @@ class HybridFactorGraph : public FactorGraph, /// Add a factor directly using a shared_ptr. void add(JacobianFactor&& factor); + + void add(DecisionTreeFactor&& factor); + + void add(boost::shared_ptr factor); + + void add(boost::shared_ptr factor); }; } // namespace gtsam diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i new file mode 100644 index 000000000..c84cd82c5 --- /dev/null +++ b/gtsam/hybrid/hybrid.i @@ -0,0 +1,134 @@ +//************************************************************************* +// hybrid +//************************************************************************* + +namespace gtsam { + +#include +virtual class HybridFactor { + void print(string s = "HybridFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const; + bool empty() const; + size_t size() const; + gtsam::KeyVector keys() const; +}; + +#include +class GaussianMixtureFactor : gtsam::HybridFactor { + static GaussianMixtureFactor FromFactorList( + const gtsam::KeyVector& continuousKeys, + const gtsam::DiscreteKeys& discreteKeys, + const std::vector& factorsList); + + void print(string s = "GaussianMixtureFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +class GaussianMixture : gtsam::HybridFactor { + static GaussianMixture FromConditionalList( + const gtsam::KeyVector& continuousFrontals, + const gtsam::KeyVector& continuousParents, + const gtsam::DiscreteKeys& discreteParents, + const std::vector& + conditionalsList); + + void print(string s = "GaussianMixture\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +class HybridBayesTreeClique { + HybridBayesTreeClique(); + HybridBayesTreeClique(const gtsam::HybridConditional* conditional); + const gtsam::HybridConditional* conditional() const; + bool isRoot() const; + // double evaluate(const gtsam::HybridValues& values) const; +}; + +#include +class HybridBayesTree { + HybridBayesTree(); + void print(string s = "HybridBayesTree\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridBayesTree& other, double tol = 1e-9) const; + + size_t size() const; + bool empty() const; + const HybridBayesTreeClique* operator[](size_t j) const; + + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +class HybridBayesNet { + HybridBayesNet(); + void add(const gtsam::HybridConditional& s); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::HybridConditional* at(size_t i) const; + void print(string s = "HybridBayesNet\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridBayesNet& other, double tol = 1e-9) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; +}; + +#include +class HybridFactorGraph { + HybridFactorGraph(); + HybridFactorGraph(const gtsam::HybridBayesNet& bayesNet); + + // Building the graph + void push_back(const gtsam::HybridFactor* factor); + void push_back(const gtsam::HybridConditional* conditional); + void push_back(const gtsam::HybridFactorGraph& graph); + void push_back(const gtsam::HybridBayesNet& bayesNet); + void push_back(const gtsam::HybridBayesTree& bayesTree); + void push_back(const gtsam::GaussianMixtureFactor* gmm); + + void add(gtsam::DecisionTreeFactor* factor); + void add(gtsam::JacobianFactor* factor); + + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::HybridFactor* at(size_t i) const; + + void print(string s = "") const; + bool equals(const gtsam::HybridFactorGraph& fg, double tol = 1e-9) const; + + gtsam::HybridBayesNet* eliminateSequential(); + gtsam::HybridBayesNet* eliminateSequential( + gtsam::Ordering::OrderingType type); + gtsam::HybridBayesNet* eliminateSequential(const gtsam::Ordering& ordering); + pair + eliminatePartialSequential(const gtsam::Ordering& ordering); + + gtsam::HybridBayesTree* eliminateMultifrontal(); + gtsam::HybridBayesTree* eliminateMultifrontal( + gtsam::Ordering::OrderingType type); + gtsam::HybridBayesTree* eliminateMultifrontal( + const gtsam::Ordering& ordering); + pair + eliminatePartialMultifrontal(const gtsam::Ordering& ordering); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridConditional.cpp b/gtsam/hybrid/tests/testHybridConditional.cpp index 95703b0af..97685b87d 100644 --- a/gtsam/hybrid/tests/testHybridConditional.cpp +++ b/gtsam/hybrid/tests/testHybridConditional.cpp @@ -133,15 +133,19 @@ TEST(HybridFactorGraph, eliminateFullMultifrontalSimple) { hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); - DecisionTree dt( - C(1), boost::make_shared(X(1), I_3x3, Z_3x1), - boost::make_shared(X(1), I_3x3, Vector3::Ones())); + // DecisionTree dt( + // C(1), boost::make_shared(X(1), I_3x3, Z_3x1), + // boost::make_shared(X(1), I_3x3, Vector3::Ones())); + + // hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); + hfg.add(GaussianMixtureFactor::FromFactorList( + {X(1)}, {{C(1), 2}}, + {boost::make_shared(X(1), I_3x3, Z_3x1), + boost::make_shared(X(1), I_3x3, Vector3::Ones())})); - hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); // hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt)); - hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8}))); - hfg.add(HybridDiscreteFactor( - DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); + hfg.add(DecisionTreeFactor(c1, {2, 8})); + hfg.add(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4")); // hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(2), 2}, {C(3), 2}}, "1 // 2 3 4"))); hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(3), 2}, // {C(1), 2}}, "1 2 2 1"))); @@ -199,11 +203,14 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) { hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1)); { - DecisionTree dt( - C(0), boost::make_shared(X(0), I_3x3, Z_3x1), - boost::make_shared(X(0), I_3x3, Vector3::Ones())); + // DecisionTree dt( + // C(0), boost::make_shared(X(0), I_3x3, Z_3x1), + // boost::make_shared(X(0), I_3x3, Vector3::Ones())); - hfg.add(GaussianMixtureFactor({X(0)}, {{C(0), 2}}, dt)); + hfg.add(GaussianMixtureFactor::FromFactorList( + {X(0)}, {{C(0), 2}}, + {boost::make_shared(X(0), I_3x3, Z_3x1), + boost::make_shared(X(0), I_3x3, Vector3::Ones())})); DecisionTree dt1( C(1), boost::make_shared(X(2), I_3x3, Z_3x1), diff --git a/gtsam/inference/inference.i b/gtsam/inference/inference.i index e7b074ec4..bbf1b2daa 100644 --- a/gtsam/inference/inference.i +++ b/gtsam/inference/inference.i @@ -9,6 +9,7 @@ namespace gtsam { #include #include #include +#include #include @@ -106,36 +107,36 @@ class Ordering { template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}> static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}> static gtsam::Ordering ColamdConstrainedLast( const FACTOR_GRAPH& graph, const gtsam::KeyVector& constrainLast, bool forceOrder = false); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}> static gtsam::Ordering ColamdConstrainedFirst( const FACTOR_GRAPH& graph, const gtsam::KeyVector& constrainFirst, bool forceOrder = false); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}> static gtsam::Ordering Natural(const FACTOR_GRAPH& graph); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}> static gtsam::Ordering Metis(const FACTOR_GRAPH& graph); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}> static gtsam::Ordering Create(gtsam::Ordering::OrderingType orderingType, const FACTOR_GRAPH& graph); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 85ddc7b6d..5059ac95d 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -65,6 +65,7 @@ set(interface_headers ${PROJECT_SOURCE_DIR}/gtsam/sfm/sfm.i ${PROJECT_SOURCE_DIR}/gtsam/navigation/navigation.i ${PROJECT_SOURCE_DIR}/gtsam/basis/basis.i + ${PROJECT_SOURCE_DIR}/gtsam/hybrid/hybrid.i ) set(GTSAM_PYTHON_TARGET gtsam_py) diff --git a/python/gtsam/preamble/hybrid.h b/python/gtsam/preamble/hybrid.h new file mode 100644 index 000000000..56a07cfdd --- /dev/null +++ b/python/gtsam/preamble/hybrid.h @@ -0,0 +1,14 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11 + * automatic STL binding, such that the raw objects can be accessed in Python. + * Without this they will be automatically converted to a Python object, and all + * mutations on Python side will not be reflected on C++. + */ + +#include diff --git a/python/gtsam/specializations/hybrid.h b/python/gtsam/specializations/hybrid.h new file mode 100644 index 000000000..bede6d86c --- /dev/null +++ b/python/gtsam/specializations/hybrid.h @@ -0,0 +1,4 @@ + +py::bind_vector >(m_, "GaussianFactorVector"); + +py::implicitly_convertible >(); diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py new file mode 100644 index 000000000..d8b0f1f33 --- /dev/null +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -0,0 +1,49 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Hybrid Factor Graphs. +Author: Fan Jiang +""" +# pylint: disable=invalid-name, no-name-in-module, no-member + +from __future__ import print_function + +import unittest + +import gtsam +import numpy as np +from gtsam.symbol_shorthand import X, C +from gtsam.utils.test_case import GtsamTestCase + + +class TestHybridFactorGraph(GtsamTestCase): + def test_create(self): + noiseModel = gtsam.noiseModel.Unit.Create(3) + dk = gtsam.DiscreteKeys() + dk.push_back((C(0), 2)) + + # print(dk.at(0)) + jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)), + noiseModel) + jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)), + noiseModel) + + gmf = gtsam.GaussianMixtureFactor.FromFactorList([X(0)], dk, + [jf1, jf2]) + + hfg = gtsam.HybridFactorGraph() + hfg.add(jf1) + hfg.add(jf2) + hfg.push_back(gmf) + + hfg.eliminateSequential( + gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph( + hfg, [C(0)])).print() + + +if __name__ == "__main__": + unittest.main()