Add wrapping for hybrid nonlinear
parent
ff20a50163
commit
1bb7a00000
|
@ -26,6 +26,12 @@ void HybridNonlinearFactorGraph::add(
|
||||||
FactorGraph::add(boost::make_shared<HybridNonlinearFactor>(factor));
|
FactorGraph::add(boost::make_shared<HybridNonlinearFactor>(factor));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
void HybridNonlinearFactorGraph::add(
|
||||||
|
boost::shared_ptr<DiscreteFactor> factor) {
|
||||||
|
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridNonlinearFactorGraph::print(const std::string& s,
|
void HybridNonlinearFactorGraph::print(const std::string& s,
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter& keyFormatter) const {
|
||||||
|
|
|
@ -112,6 +112,9 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
||||||
/// Add a nonlinear factor as a shared ptr.
|
/// Add a nonlinear factor as a shared ptr.
|
||||||
void add(boost::shared_ptr<NonlinearFactor> factor);
|
void add(boost::shared_ptr<NonlinearFactor> factor);
|
||||||
|
|
||||||
|
/// Add a discrete factor as a shared ptr.
|
||||||
|
void add(boost::shared_ptr<DiscreteFactor> factor);
|
||||||
|
|
||||||
/// Print the factor graph.
|
/// Print the factor graph.
|
||||||
void print(
|
void print(
|
||||||
const std::string& s = "HybridNonlinearFactorGraph",
|
const std::string& s = "HybridNonlinearFactorGraph",
|
||||||
|
|
|
@ -39,7 +39,17 @@ virtual class HybridConditional {
|
||||||
bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const;
|
bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const;
|
||||||
size_t nrFrontals() const;
|
size_t nrFrontals() const;
|
||||||
size_t nrParents() const;
|
size_t nrParents() const;
|
||||||
Factor* inner();
|
gtsam::Factor* inner();
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||||
|
virtual class HybridDiscreteFactor {
|
||||||
|
HybridDiscreteFactor(gtsam::DecisionTreeFactor dtf);
|
||||||
|
void print(string s = "HybridDiscreteFactor\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const;
|
||||||
|
gtsam::Factor* inner();
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
|
@ -132,6 +142,7 @@ class HybridGaussianFactorGraph {
|
||||||
void add(gtsam::JacobianFactor* factor);
|
void add(gtsam::JacobianFactor* factor);
|
||||||
|
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
|
void remove(size_t i);
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
gtsam::KeySet keys() const;
|
gtsam::KeySet keys() const;
|
||||||
const gtsam::HybridFactor* at(size_t i) const;
|
const gtsam::HybridFactor* at(size_t i) const;
|
||||||
|
@ -159,4 +170,50 @@ class HybridGaussianFactorGraph {
|
||||||
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||||
|
class HybridNonlinearFactorGraph {
|
||||||
|
HybridNonlinearFactorGraph();
|
||||||
|
HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph);
|
||||||
|
void push_back(gtsam::HybridFactor* factor);
|
||||||
|
void push_back(gtsam::NonlinearFactor* factor);
|
||||||
|
void push_back(gtsam::HybridDiscreteFactor* factor);
|
||||||
|
void add(gtsam::NonlinearFactor* factor);
|
||||||
|
void add(gtsam::DiscreteFactor* factor);
|
||||||
|
gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const;
|
||||||
|
|
||||||
|
bool empty() const;
|
||||||
|
void remove(size_t i);
|
||||||
|
size_t size() const;
|
||||||
|
gtsam::KeySet keys() const;
|
||||||
|
const gtsam::HybridFactor* at(size_t i) const;
|
||||||
|
|
||||||
|
void print(string s = "HybridNonlinearFactorGraph\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
#include <gtsam/hybrid/MixtureFactor.h>
|
||||||
|
class MixtureFactor : gtsam::HybridFactor {
|
||||||
|
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
|
||||||
|
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors, bool normalized = false);
|
||||||
|
|
||||||
|
template <FACTOR = {gtsam::NonlinearFactor}>
|
||||||
|
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
|
||||||
|
const std::vector<FACTOR*>& factors,
|
||||||
|
bool normalized = false);
|
||||||
|
|
||||||
|
double error(const gtsam::Values& continuousVals,
|
||||||
|
const gtsam::DiscreteValues& discreteVals) const;
|
||||||
|
|
||||||
|
double nonlinearFactorLogNormalizingConstant(const gtsam::NonlinearFactor* factor,
|
||||||
|
const gtsam::Values& values) const;
|
||||||
|
|
||||||
|
GaussianMixtureFactor* linearize(
|
||||||
|
const gtsam::Values& continuousVals) const;
|
||||||
|
|
||||||
|
void print(string s = "MixtureFactor\n",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -10,5 +10,12 @@
|
||||||
* Without this they will be automatically converted to a Python object, and all
|
* Without this they will be automatically converted to a Python object, and all
|
||||||
* mutations on Python side will not be reflected on C++.
|
* mutations on Python side will not be reflected on C++.
|
||||||
*/
|
*/
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
#ifdef GTSAM_ALLOCATOR_TBB
|
||||||
|
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::Key, tbb::tbb_allocator<gtsam::Key>>);
|
||||||
|
#else
|
||||||
|
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::Key>);
|
||||||
|
#endif
|
||||||
|
|
||||||
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::GaussianFactor::shared_ptr>);
|
PYBIND11_MAKE_OPAQUE(std::vector<gtsam::GaussianFactor::shared_ptr>);
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
"""
|
||||||
|
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
|
||||||
|
Atlanta, Georgia 30332-0415
|
||||||
|
All Rights Reserved
|
||||||
|
|
||||||
|
See LICENSE for the license information
|
||||||
|
|
||||||
|
Unit tests for Hybrid Nonlinear Factor Graphs.
|
||||||
|
Author: Fan Jiang
|
||||||
|
"""
|
||||||
|
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import gtsam
|
||||||
|
import numpy as np
|
||||||
|
from gtsam.symbol_shorthand import C, X
|
||||||
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||||
|
"""Unit tests for HybridGaussianFactorGraph."""
|
||||||
|
|
||||||
|
def test_nonlinear_hybrid(self):
|
||||||
|
nlfg = gtsam.HybridNonlinearFactorGraph()
|
||||||
|
dk = gtsam.DiscreteKeys()
|
||||||
|
dk.push_back((10, 2))
|
||||||
|
nlfg.add(gtsam.BetweenFactorPoint3(1, 2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([1, 1, 1])))
|
||||||
|
nlfg.add(
|
||||||
|
gtsam.PriorFactorPoint3(2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([0.5, 0.5, 0.5])))
|
||||||
|
nlfg.push_back(
|
||||||
|
gtsam.MixtureFactor([1], dk, [
|
||||||
|
gtsam.PriorFactorPoint3(1, gtsam.Point3(0, 0, 0),
|
||||||
|
gtsam.noiseModel.Unit.Create(3)),
|
||||||
|
gtsam.PriorFactorPoint3(1, gtsam.Point3(1, 2, 1),
|
||||||
|
gtsam.noiseModel.Unit.Create(3))
|
||||||
|
]))
|
||||||
|
nlfg.add(gtsam.DecisionTreeFactor((10, 2), "1 3"))
|
||||||
|
values = gtsam.Values()
|
||||||
|
values.insert_point3(1, gtsam.Point3(0, 0, 0))
|
||||||
|
values.insert_point3(2, gtsam.Point3(2, 3, 1))
|
||||||
|
hfg = nlfg.linearize(values)
|
||||||
|
o = gtsam.Ordering()
|
||||||
|
o.push_back(1)
|
||||||
|
o.push_back(2)
|
||||||
|
o.push_back(10)
|
||||||
|
hbn = hfg.eliminateSequential(o)
|
||||||
|
hbv = hbn.optimize()
|
||||||
|
self.assertEqual(hbv.atDiscrete(10), 0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue