Add Python bindings

release/4.3a0
Fan Jiang 2022-03-25 17:14:00 -06:00
parent 8d888606a3
commit d2dc620b1e
15 changed files with 311 additions and 47 deletions

View File

@ -35,12 +35,23 @@ GaussianMixture::GaussianMixture(
BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {}
const GaussianMixture::Conditionals& GaussianMixture::conditionals() {
const GaussianMixture::Conditionals &GaussianMixture::conditionals() {
return conditionals_;
}
GaussianMixture GaussianMixture::FromConditionalList(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
Conditionals dt(discreteParents, conditionalsList);
return GaussianMixture(continuousFrontals, continuousParents, discreteParents,
dt);
}
/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::addTo(const GaussianMixture::Sum &sum) const {
GaussianMixture::Sum GaussianMixture::addTo(
const GaussianMixture::Sum &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
@ -66,7 +77,7 @@ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
}
void GaussianMixture::print(const std::string &s,
const KeyFormatter &formatter) const {
const KeyFormatter &formatter) const {
std::cout << s << ": ";
if (isContinuous_) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. ";
@ -88,4 +99,4 @@ void GaussianMixture::print(const std::string &s,
return rd.str();
});
}
} // namespace gtsam
} // namespace gtsam

View File

@ -60,6 +60,11 @@ class GaussianMixture : public HybridFactor,
/* *******************************************************************************/
Sum wrappedConditionals() const;
static This FromConditionalList(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals);
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
void print(

View File

@ -18,43 +18,52 @@
* @date Mar 12, 2022
*/
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors) : Base(continuousKeys, discreteKeys),
factors_(factors) {}
const DiscreteKeys &discreteKeys,
const Factors &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
return false;
}
void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const {
GaussianMixtureFactor GaussianMixtureFactor::FromFactorList(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factorsList) {
Factors dt(discreteKeys, factorsList);
return GaussianMixtureFactor(continuousKeys, discreteKeys, dt);
}
void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
factors_.print(
"mixture = ",
[&](Key k) {
return formatter(k);
}, [&](const GaussianFactor::shared_ptr &gf) -> std::string {
"mixture = ", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (!gf->empty()) gf->print("", formatter);
else return {"nullptr"};
if (!gf->empty())
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
});
}
const GaussianMixtureFactor::Factors& GaussianMixtureFactor::factors() {
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
return factors_;
}
/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::addTo(const GaussianMixtureFactor::Sum &sum) const {
GaussianMixtureFactor::Sum GaussianMixtureFactor::addTo(
const GaussianMixtureFactor::Sum &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
@ -74,4 +83,4 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::wrappedFactors() const {
};
return {factors_, wrap};
}
}
} // namespace gtsam

View File

@ -29,6 +29,8 @@ namespace gtsam {
class GaussianFactorGraph;
typedef std::vector<gtsam::GaussianFactor::shared_ptr> GaussianFactorVector;
class GaussianMixtureFactor : public HybridFactor {
public:
using Base = HybridFactor;
@ -42,11 +44,16 @@ class GaussianMixtureFactor : public HybridFactor {
GaussianMixtureFactor() = default;
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const Factors &factors);
const DiscreteKeys &discreteKeys,
const Factors &factors);
using Sum = DecisionTree<Key, GaussianFactorGraph>;
const Factors& factors();
const Factors &factors();
static This FromFactorList(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors);
/* *******************************************************************************/
Sum addTo(const Sum &sum) const;

View File

@ -16,9 +16,9 @@
*/
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional-inst.h>
#include "gtsam/hybrid/HybridFactor.h"
#include "gtsam/inference/Key.h"
#include <gtsam/inference/Key.h>
namespace gtsam {
@ -27,8 +27,9 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents)
: HybridConditional(
CollectKeys({continuousFrontals.begin(), continuousFrontals.end()},
KeyVector {continuousParents.begin(), continuousParents.end()}),
CollectKeys(
{continuousFrontals.begin(), continuousFrontals.end()},
KeyVector{continuousParents.begin(), continuousParents.end()}),
CollectDiscreteKeys(
{discreteFrontals.begin(), discreteFrontals.end()},
{discreteParents.begin(), discreteParents.end()}),

View File

@ -35,10 +35,10 @@ class HybridDiscreteFactor : public HybridFactor {
DiscreteFactor::shared_ptr inner;
// Implicit conversion from a shared ptr of GF
// Implicit conversion from a shared ptr of DF
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
// Forwarding constructor from concrete JacobianFactor
// Forwarding constructor from concrete DecisionTreeFactor
HybridDiscreteFactor(DecisionTreeFactor &&dtf);
public:

View File

@ -35,6 +35,7 @@
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/linear/HessianFactor.h>
#include <gtsam/linear/JacobianFactor.h>
#include <algorithm>
#include <cstddef>
@ -380,4 +381,17 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
void HybridFactorGraph::add(JacobianFactor &&factor) {
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(std::move(factor)));
}
void HybridFactorGraph::add(DecisionTreeFactor &&factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(std::move(factor)));
}
void HybridFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}
void HybridFactorGraph::add(JacobianFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(factor));
}
} // namespace gtsam

View File

@ -32,6 +32,7 @@ class HybridBayesNet;
class HybridEliminationTree;
class HybridBayesTree;
class HybridJunctionTree;
class DecisionTreeFactor;
class JacobianFactor;
@ -98,6 +99,12 @@ class HybridFactorGraph : public FactorGraph<HybridFactor>,
/// Add a factor directly using a shared_ptr.
void add(JacobianFactor&& factor);
void add(DecisionTreeFactor&& factor);
void add(boost::shared_ptr<DecisionTreeFactor> factor);
void add(boost::shared_ptr<JacobianFactor> factor);
};
} // namespace gtsam

134
gtsam/hybrid/hybrid.i Normal file
View File

@ -0,0 +1,134 @@
//*************************************************************************
// hybrid
//*************************************************************************
namespace gtsam {
#include <gtsam/hybrid/HybridFactor.h>
virtual class HybridFactor {
void print(string s = "HybridFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const;
bool empty() const;
size_t size() const;
gtsam::KeyVector keys() const;
};
#include <gtsam/hybrid/GaussianMixtureFactor.h>
class GaussianMixtureFactor : gtsam::HybridFactor {
static GaussianMixtureFactor FromFactorList(
const gtsam::KeyVector& continuousKeys,
const gtsam::DiscreteKeys& discreteKeys,
const std::vector<gtsam::GaussianFactor::shared_ptr>& factorsList);
void print(string s = "GaussianMixtureFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
#include <gtsam/hybrid/GaussianMixture.h>
class GaussianMixture : gtsam::HybridFactor {
static GaussianMixture FromConditionalList(
const gtsam::KeyVector& continuousFrontals,
const gtsam::KeyVector& continuousParents,
const gtsam::DiscreteKeys& discreteParents,
const std::vector<gtsam::GaussianConditional::shared_ptr>&
conditionalsList);
void print(string s = "GaussianMixture\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
#include <gtsam/hybrid/HybridBayesTree.h>
class HybridBayesTreeClique {
HybridBayesTreeClique();
HybridBayesTreeClique(const gtsam::HybridConditional* conditional);
const gtsam::HybridConditional* conditional() const;
bool isRoot() const;
// double evaluate(const gtsam::HybridValues& values) const;
};
#include <gtsam/hybrid/HybridBayesTree.h>
class HybridBayesTree {
HybridBayesTree();
void print(string s = "HybridBayesTree\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridBayesTree& other, double tol = 1e-9) const;
size_t size() const;
bool empty() const;
const HybridBayesTreeClique* operator[](size_t j) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
class HybridBayesNet {
HybridBayesNet();
void add(const gtsam::HybridConditional& s);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::HybridConditional* at(size_t i) const;
void print(string s = "HybridBayesNet\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridBayesNet& other, double tol = 1e-9) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};
#include <gtsam/hybrid/HybridFactorGraph.h>
class HybridFactorGraph {
HybridFactorGraph();
HybridFactorGraph(const gtsam::HybridBayesNet& bayesNet);
// Building the graph
void push_back(const gtsam::HybridFactor* factor);
void push_back(const gtsam::HybridConditional* conditional);
void push_back(const gtsam::HybridFactorGraph& graph);
void push_back(const gtsam::HybridBayesNet& bayesNet);
void push_back(const gtsam::HybridBayesTree& bayesTree);
void push_back(const gtsam::GaussianMixtureFactor* gmm);
void add(gtsam::DecisionTreeFactor* factor);
void add(gtsam::JacobianFactor* factor);
bool empty() const;
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::HybridFactor* at(size_t i) const;
void print(string s = "") const;
bool equals(const gtsam::HybridFactorGraph& fg, double tol = 1e-9) const;
gtsam::HybridBayesNet* eliminateSequential();
gtsam::HybridBayesNet* eliminateSequential(
gtsam::Ordering::OrderingType type);
gtsam::HybridBayesNet* eliminateSequential(const gtsam::Ordering& ordering);
pair<gtsam::HybridBayesNet*, gtsam::HybridFactorGraph*>
eliminatePartialSequential(const gtsam::Ordering& ordering);
gtsam::HybridBayesTree* eliminateMultifrontal();
gtsam::HybridBayesTree* eliminateMultifrontal(
gtsam::Ordering::OrderingType type);
gtsam::HybridBayesTree* eliminateMultifrontal(
const gtsam::Ordering& ordering);
pair<gtsam::HybridBayesTree*, gtsam::HybridFactorGraph*>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
};
} // namespace gtsam

View File

@ -133,15 +133,19 @@ TEST(HybridFactorGraph, eliminateFullMultifrontalSimple) {
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
// DecisionTree<Key, GaussianFactor::shared_ptr> dt(
// C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
// boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
// hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));
hfg.add(GaussianMixtureFactor::FromFactorList(
{X(1)}, {{C(1), 2}},
{boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones())}));
hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));
// hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8})));
hfg.add(HybridDiscreteFactor(
DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4")));
hfg.add(DecisionTreeFactor(c1, {2, 8}));
hfg.add(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"));
// hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(2), 2}, {C(3), 2}}, "1
// 2 3 4"))); hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(3), 2},
// {C(1), 2}}, "1 2 2 1")));
@ -199,11 +203,14 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) {
hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1));
{
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
C(0), boost::make_shared<JacobianFactor>(X(0), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(0), I_3x3, Vector3::Ones()));
// DecisionTree<Key, GaussianFactor::shared_ptr> dt(
// C(0), boost::make_shared<JacobianFactor>(X(0), I_3x3, Z_3x1),
// boost::make_shared<JacobianFactor>(X(0), I_3x3, Vector3::Ones()));
hfg.add(GaussianMixtureFactor({X(0)}, {{C(0), 2}}, dt));
hfg.add(GaussianMixtureFactor::FromFactorList(
{X(0)}, {{C(0), 2}},
{boost::make_shared<JacobianFactor>(X(0), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(0), I_3x3, Vector3::Ones())}));
DecisionTree<Key, GaussianFactor::shared_ptr> dt1(
C(1), boost::make_shared<JacobianFactor>(X(2), I_3x3, Z_3x1),

View File

@ -9,6 +9,7 @@ namespace gtsam {
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/symbolic/SymbolicFactorGraph.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/inference/Key.h>
@ -106,36 +107,36 @@ class Ordering {
template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}>
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}>
static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph);
template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}>
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}>
static gtsam::Ordering ColamdConstrainedLast(
const FACTOR_GRAPH& graph, const gtsam::KeyVector& constrainLast,
bool forceOrder = false);
template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}>
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}>
static gtsam::Ordering ColamdConstrainedFirst(
const FACTOR_GRAPH& graph, const gtsam::KeyVector& constrainFirst,
bool forceOrder = false);
template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}>
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}>
static gtsam::Ordering Natural(const FACTOR_GRAPH& graph);
template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}>
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}>
static gtsam::Ordering Metis(const FACTOR_GRAPH& graph);
template <
FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph,
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}>
gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridFactorGraph}>
static gtsam::Ordering Create(gtsam::Ordering::OrderingType orderingType,
const FACTOR_GRAPH& graph);

View File

@ -65,6 +65,7 @@ set(interface_headers
${PROJECT_SOURCE_DIR}/gtsam/sfm/sfm.i
${PROJECT_SOURCE_DIR}/gtsam/navigation/navigation.i
${PROJECT_SOURCE_DIR}/gtsam/basis/basis.i
${PROJECT_SOURCE_DIR}/gtsam/hybrid/hybrid.i
)
set(GTSAM_PYTHON_TARGET gtsam_py)

View File

@ -0,0 +1,14 @@
/* Please refer to:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
* These are required to save one copy operation on Python calls.
*
* NOTES
* =================
*
* `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11
* automatic STL binding, such that the raw objects can be accessed in Python.
* Without this they will be automatically converted to a Python object, and all
* mutations on Python side will not be reflected on C++.
*/
#include <pybind11/stl.h>

View File

@ -0,0 +1,4 @@
py::bind_vector<std::vector<gtsam::GaussianFactor::shared_ptr> >(m_, "GaussianFactorVector");
py::implicitly_convertible<py::list, std::vector<gtsam::GaussianFactor::shared_ptr> >();

View File

@ -0,0 +1,49 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Hybrid Factor Graphs.
Author: Fan Jiang
"""
# pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function
import unittest
import gtsam
import numpy as np
from gtsam.symbol_shorthand import X, C
from gtsam.utils.test_case import GtsamTestCase
class TestHybridFactorGraph(GtsamTestCase):
def test_create(self):
noiseModel = gtsam.noiseModel.Unit.Create(3)
dk = gtsam.DiscreteKeys()
dk.push_back((C(0), 2))
# print(dk.at(0))
jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)),
noiseModel)
jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)),
noiseModel)
gmf = gtsam.GaussianMixtureFactor.FromFactorList([X(0)], dk,
[jf1, jf2])
hfg = gtsam.HybridFactorGraph()
hfg.add(jf1)
hfg.add(jf2)
hfg.push_back(gmf)
hfg.eliminateSequential(
gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph(
hfg, [C(0)])).print()
if __name__ == "__main__":
unittest.main()