add doc strings to python unit test and add assertions
parent
9d26a3dc9d
commit
3bde044248
|
@ -85,12 +85,12 @@ bool GaussianMixtureConditional::equals(const HybridFactor &lf,
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
void GaussianMixtureConditional::print(const std::string &s,
|
void GaussianMixtureConditional::print(const std::string &s,
|
||||||
const KeyFormatter &formatter) const {
|
const KeyFormatter &formatter) const {
|
||||||
std::cout << s << ": ";
|
std::cout << s;
|
||||||
if (isContinuous()) std::cout << "Cont. ";
|
if (isContinuous()) std::cout << "Continuous ";
|
||||||
if (isDiscrete()) std::cout << "Disc. ";
|
if (isDiscrete()) std::cout << "Discrete ";
|
||||||
if (isHybrid()) std::cout << "Hybr. ";
|
if (isHybrid()) std::cout << "Hybrid ";
|
||||||
BaseConditional::print("", formatter);
|
BaseConditional::print("", formatter);
|
||||||
std::cout << "Discrete Keys = ";
|
std::cout << "\nDiscrete Keys = ";
|
||||||
for (auto &dk : discreteKeys()) {
|
for (auto &dk : discreteKeys()) {
|
||||||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,30 +67,36 @@ HybridConditional::HybridConditional(
|
||||||
void HybridConditional::print(const std::string &s,
|
void HybridConditional::print(const std::string &s,
|
||||||
const KeyFormatter &formatter) const {
|
const KeyFormatter &formatter) const {
|
||||||
std::cout << s;
|
std::cout << s;
|
||||||
if (isContinuous()) std::cout << "Cont. ";
|
|
||||||
if (isDiscrete()) std::cout << "Disc. ";
|
if (inner_) {
|
||||||
if (isHybrid()) std::cout << "Hybr. ";
|
inner_->print("", formatter);
|
||||||
std::cout << "P(";
|
|
||||||
size_t index = 0;
|
} else {
|
||||||
const size_t N = keys().size();
|
if (isContinuous()) std::cout << "Continuous ";
|
||||||
const size_t contN = N - discreteKeys_.size();
|
if (isDiscrete()) std::cout << "Discrete ";
|
||||||
while (index < N) {
|
if (isHybrid()) std::cout << "Hybrid ";
|
||||||
if (index > 0) {
|
BaseConditional::print("", formatter);
|
||||||
if (index == nrFrontals_)
|
|
||||||
std::cout << " | ";
|
std::cout << "P(";
|
||||||
else
|
size_t index = 0;
|
||||||
std::cout << ", ";
|
const size_t N = keys().size();
|
||||||
|
const size_t contN = N - discreteKeys_.size();
|
||||||
|
while (index < N) {
|
||||||
|
if (index > 0) {
|
||||||
|
if (index == nrFrontals_)
|
||||||
|
std::cout << " | ";
|
||||||
|
else
|
||||||
|
std::cout << ", ";
|
||||||
|
}
|
||||||
|
if (index < contN) {
|
||||||
|
std::cout << formatter(keys()[index]);
|
||||||
|
} else {
|
||||||
|
auto &dk = discreteKeys_[index - contN];
|
||||||
|
std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")";
|
||||||
|
}
|
||||||
|
index++;
|
||||||
}
|
}
|
||||||
if (index < contN) {
|
|
||||||
std::cout << formatter(keys()[index]);
|
|
||||||
} else {
|
|
||||||
auto &dk = discreteKeys_[index - contN];
|
|
||||||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")";
|
|
||||||
}
|
|
||||||
index++;
|
|
||||||
}
|
}
|
||||||
std::cout << ")\n";
|
|
||||||
if (inner_) inner_->print("", formatter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
@ -77,9 +77,9 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
void HybridFactor::print(const std::string &s,
|
void HybridFactor::print(const std::string &s,
|
||||||
const KeyFormatter &formatter) const {
|
const KeyFormatter &formatter) const {
|
||||||
std::cout << s;
|
std::cout << s;
|
||||||
if (isContinuous_) std::cout << "Cont. ";
|
if (isContinuous_) std::cout << "Continuous ";
|
||||||
if (isDiscrete_) std::cout << "Disc. ";
|
if (isDiscrete_) std::cout << "Discrete ";
|
||||||
if (isHybrid_) std::cout << "Hybr. ";
|
if (isHybrid_) std::cout << "Hybrid ";
|
||||||
this->printKeys("", formatter);
|
this->printKeys("", formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,24 +16,25 @@ import unittest
|
||||||
|
|
||||||
import gtsam
|
import gtsam
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam.symbol_shorthand import X, C
|
from gtsam.symbol_shorthand import C, X
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
|
||||||
class TestHybridFactorGraph(GtsamTestCase):
|
class TestHybridFactorGraph(GtsamTestCase):
|
||||||
|
"""Unit tests for HybridFactorGraph."""
|
||||||
|
|
||||||
def test_create(self):
|
def test_create(self):
|
||||||
|
"""Test contruction of hybrid factor graph."""
|
||||||
noiseModel = gtsam.noiseModel.Unit.Create(3)
|
noiseModel = gtsam.noiseModel.Unit.Create(3)
|
||||||
dk = gtsam.DiscreteKeys()
|
dk = gtsam.DiscreteKeys()
|
||||||
dk.push_back((C(0), 2))
|
dk.push_back((C(0), 2))
|
||||||
|
|
||||||
# print(dk.at(0))
|
|
||||||
jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)),
|
jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)),
|
||||||
noiseModel)
|
noiseModel)
|
||||||
jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)),
|
jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)),
|
||||||
noiseModel)
|
noiseModel)
|
||||||
|
|
||||||
gmf = gtsam.GaussianMixtureFactor.FromFactorList([X(0)], dk,
|
gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2])
|
||||||
[jf1, jf2])
|
|
||||||
|
|
||||||
hfg = gtsam.HybridFactorGraph()
|
hfg = gtsam.HybridFactorGraph()
|
||||||
hfg.add(jf1)
|
hfg.add(jf1)
|
||||||
|
@ -41,16 +42,17 @@ class TestHybridFactorGraph(GtsamTestCase):
|
||||||
hfg.push_back(gmf)
|
hfg.push_back(gmf)
|
||||||
|
|
||||||
hbn = hfg.eliminateSequential(
|
hbn = hfg.eliminateSequential(
|
||||||
gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph(
|
gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph(hfg, [C(0)]))
|
||||||
hfg, [C(0)]))
|
|
||||||
|
|
||||||
print("hbn = ", hbn)
|
# print("hbn = ", hbn)
|
||||||
|
self.assertEqual(hbn.size(), 2)
|
||||||
|
|
||||||
mixture = hbn.at(0).getInner()
|
mixture = hbn.at(0).inner()
|
||||||
print(mixture)
|
self.assertIsInstance(mixture, gtsam.GaussianMixtureConditional)
|
||||||
|
self.assertEqual(len(mixture.keys()), 2)
|
||||||
|
|
||||||
discrete_conditional = hbn.at(hbn.size()-1).getInner()
|
discrete_conditional = hbn.at(hbn.size() - 1).inner()
|
||||||
print(discrete_conditional)
|
self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue