add doc strings to python unit test and add assertions

release/4.3a0
Varun Agrawal 2022-05-27 18:13:06 -04:00
parent 9d26a3dc9d
commit 3bde044248
4 changed files with 49 additions and 41 deletions

View File

@ -85,12 +85,12 @@ bool GaussianMixtureConditional::equals(const HybridFactor &lf,
/* *******************************************************************************/ /* *******************************************************************************/
void GaussianMixtureConditional::print(const std::string &s, void GaussianMixtureConditional::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << s << ": "; std::cout << s;
if (isContinuous()) std::cout << "Cont. "; if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Disc. "; if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybr. "; if (isHybrid()) std::cout << "Hybrid ";
BaseConditional::print("", formatter); BaseConditional::print("", formatter);
std::cout << "Discrete Keys = "; std::cout << "\nDiscrete Keys = ";
for (auto &dk : discreteKeys()) { for (auto &dk : discreteKeys()) {
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
} }

View File

@ -67,30 +67,36 @@ HybridConditional::HybridConditional(
void HybridConditional::print(const std::string &s, void HybridConditional::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << s; std::cout << s;
if (isContinuous()) std::cout << "Cont. ";
if (isDiscrete()) std::cout << "Disc. "; if (inner_) {
if (isHybrid()) std::cout << "Hybr. "; inner_->print("", formatter);
std::cout << "P(";
size_t index = 0; } else {
const size_t N = keys().size(); if (isContinuous()) std::cout << "Continuous ";
const size_t contN = N - discreteKeys_.size(); if (isDiscrete()) std::cout << "Discrete ";
while (index < N) { if (isHybrid()) std::cout << "Hybrid ";
if (index > 0) { BaseConditional::print("", formatter);
if (index == nrFrontals_)
std::cout << " | "; std::cout << "P(";
else size_t index = 0;
std::cout << ", "; const size_t N = keys().size();
const size_t contN = N - discreteKeys_.size();
while (index < N) {
if (index > 0) {
if (index == nrFrontals_)
std::cout << " | ";
else
std::cout << ", ";
}
if (index < contN) {
std::cout << formatter(keys()[index]);
} else {
auto &dk = discreteKeys_[index - contN];
std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")";
}
index++;
} }
if (index < contN) {
std::cout << formatter(keys()[index]);
} else {
auto &dk = discreteKeys_[index - contN];
std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")";
}
index++;
} }
std::cout << ")\n";
if (inner_) inner_->print("", formatter);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -77,9 +77,9 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
void HybridFactor::print(const std::string &s, void HybridFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << s; std::cout << s;
if (isContinuous_) std::cout << "Cont. "; if (isContinuous_) std::cout << "Continuous ";
if (isDiscrete_) std::cout << "Disc. "; if (isDiscrete_) std::cout << "Discrete ";
if (isHybrid_) std::cout << "Hybr. "; if (isHybrid_) std::cout << "Hybrid ";
this->printKeys("", formatter); this->printKeys("", formatter);
} }

View File

@ -16,24 +16,25 @@ import unittest
import gtsam import gtsam
import numpy as np import numpy as np
from gtsam.symbol_shorthand import X, C from gtsam.symbol_shorthand import C, X
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
class TestHybridFactorGraph(GtsamTestCase): class TestHybridFactorGraph(GtsamTestCase):
"""Unit tests for HybridFactorGraph."""
def test_create(self): def test_create(self):
"""Test contruction of hybrid factor graph."""
noiseModel = gtsam.noiseModel.Unit.Create(3) noiseModel = gtsam.noiseModel.Unit.Create(3)
dk = gtsam.DiscreteKeys() dk = gtsam.DiscreteKeys()
dk.push_back((C(0), 2)) dk.push_back((C(0), 2))
# print(dk.at(0))
jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)), jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)),
noiseModel) noiseModel)
jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)), jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)),
noiseModel) noiseModel)
gmf = gtsam.GaussianMixtureFactor.FromFactorList([X(0)], dk, gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2])
[jf1, jf2])
hfg = gtsam.HybridFactorGraph() hfg = gtsam.HybridFactorGraph()
hfg.add(jf1) hfg.add(jf1)
@ -41,16 +42,17 @@ class TestHybridFactorGraph(GtsamTestCase):
hfg.push_back(gmf) hfg.push_back(gmf)
hbn = hfg.eliminateSequential( hbn = hfg.eliminateSequential(
gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph( gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph(hfg, [C(0)]))
hfg, [C(0)]))
print("hbn = ", hbn) # print("hbn = ", hbn)
self.assertEqual(hbn.size(), 2)
mixture = hbn.at(0).getInner() mixture = hbn.at(0).inner()
print(mixture) self.assertIsInstance(mixture, gtsam.GaussianMixtureConditional)
self.assertEqual(len(mixture.keys()), 2)
discrete_conditional = hbn.at(hbn.size()-1).getInner() discrete_conditional = hbn.at(hbn.size() - 1).inner()
print(discrete_conditional) self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional)
if __name__ == "__main__": if __name__ == "__main__":