add doc strings to python unit test and add assertions

release/4.3a0
Varun Agrawal 2022-05-27 18:13:06 -04:00
parent 9d26a3dc9d
commit 3bde044248
4 changed files with 49 additions and 41 deletions

View File

@ -85,12 +85,12 @@ bool GaussianMixtureConditional::equals(const HybridFactor &lf,
/* *******************************************************************************/
void GaussianMixtureConditional::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << s << ": ";
if (isContinuous()) std::cout << "Cont. ";
if (isDiscrete()) std::cout << "Disc. ";
if (isHybrid()) std::cout << "Hybr. ";
std::cout << s;
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid ";
BaseConditional::print("", formatter);
std::cout << "Discrete Keys = ";
std::cout << "\nDiscrete Keys = ";
for (auto &dk : discreteKeys()) {
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}

View File

@ -67,30 +67,36 @@ HybridConditional::HybridConditional(
void HybridConditional::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << s;
if (isContinuous()) std::cout << "Cont. ";
if (isDiscrete()) std::cout << "Disc. ";
if (isHybrid()) std::cout << "Hybr. ";
std::cout << "P(";
size_t index = 0;
const size_t N = keys().size();
const size_t contN = N - discreteKeys_.size();
while (index < N) {
if (index > 0) {
if (index == nrFrontals_)
std::cout << " | ";
else
std::cout << ", ";
if (inner_) {
inner_->print("", formatter);
} else {
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid ";
BaseConditional::print("", formatter);
std::cout << "P(";
size_t index = 0;
const size_t N = keys().size();
const size_t contN = N - discreteKeys_.size();
while (index < N) {
if (index > 0) {
if (index == nrFrontals_)
std::cout << " | ";
else
std::cout << ", ";
}
if (index < contN) {
std::cout << formatter(keys()[index]);
} else {
auto &dk = discreteKeys_[index - contN];
std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")";
}
index++;
}
if (index < contN) {
std::cout << formatter(keys()[index]);
} else {
auto &dk = discreteKeys_[index - contN];
std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")";
}
index++;
}
std::cout << ")\n";
if (inner_) inner_->print("", formatter);
}
/* ************************************************************************ */

View File

@ -77,9 +77,9 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
void HybridFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << s;
if (isContinuous_) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. ";
if (isHybrid_) std::cout << "Hybr. ";
if (isContinuous_) std::cout << "Continuous ";
if (isDiscrete_) std::cout << "Discrete ";
if (isHybrid_) std::cout << "Hybrid ";
this->printKeys("", formatter);
}

View File

@ -16,24 +16,25 @@ import unittest
import gtsam
import numpy as np
from gtsam.symbol_shorthand import X, C
from gtsam.symbol_shorthand import C, X
from gtsam.utils.test_case import GtsamTestCase
class TestHybridFactorGraph(GtsamTestCase):
"""Unit tests for HybridFactorGraph."""
def test_create(self):
"""Test contruction of hybrid factor graph."""
noiseModel = gtsam.noiseModel.Unit.Create(3)
dk = gtsam.DiscreteKeys()
dk.push_back((C(0), 2))
# print(dk.at(0))
jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)),
noiseModel)
jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)),
noiseModel)
gmf = gtsam.GaussianMixtureFactor.FromFactorList([X(0)], dk,
[jf1, jf2])
gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2])
hfg = gtsam.HybridFactorGraph()
hfg.add(jf1)
@ -41,16 +42,17 @@ class TestHybridFactorGraph(GtsamTestCase):
hfg.push_back(gmf)
hbn = hfg.eliminateSequential(
gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph(
hfg, [C(0)]))
gtsam.Ordering.ColamdConstrainedLastHybridFactorGraph(hfg, [C(0)]))
print("hbn = ", hbn)
# print("hbn = ", hbn)
self.assertEqual(hbn.size(), 2)
mixture = hbn.at(0).getInner()
print(mixture)
mixture = hbn.at(0).inner()
self.assertIsInstance(mixture, gtsam.GaussianMixtureConditional)
self.assertEqual(len(mixture.keys()), 2)
discrete_conditional = hbn.at(hbn.size()-1).getInner()
print(discrete_conditional)
discrete_conditional = hbn.at(hbn.size() - 1).inner()
self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional)
if __name__ == "__main__":