break down tests to make reporting clearer

release/4.3a0
Varun Agrawal 2021-08-29 04:36:57 -04:00
parent 65837c1030
commit 289cb8f35b
1 changed files with 65 additions and 27 deletions

View File

@ -19,39 +19,77 @@ from gtsam.symbol_shorthand import B
class TestBasis(GtsamTestCase): class TestBasis(GtsamTestCase):
""" """
Tests Basis module python bindings. Tests FitBasis python binding for FourierBasis, Chebyshev1Basis, Chebyshev2Basis, and Chebyshev2.
"""
def test_fit_basis(self):
"""
Tests FitBasis python binding for FourierBasis, Chebyshev1Basis, Chebyshev2Basis, and
Chebyshev2.
It tests FitBasis by fitting to a ground-truth function that can be represented exactly in It tests FitBasis by fitting to a ground-truth function that can be represented exactly in
the basis, then checking that the regression (fit result) matches the function. For the the basis, then checking that the regression (fit result) matches the function. For the
Chebyshev bases, the line y=x is used to generate the data while for Fourier, 0.7*cos(x) is Chebyshev bases, the line y=x is used to generate the data while for Fourier, 0.7*cos(x) is
used. used.
""" """
f = lambda x: x # line y = x def setUp(self):
N = 2 self.N = 2
datax = [0., 0.5, 0.75] self.x = [0., 0.5, 0.75]
interpx = np.linspace(0., 1., 10) self.interpx = np.linspace(0., 1., 10)
noise = gtsam.noiseModel.Unit.Create(1) self.noise = gtsam.noiseModel.Unit.Create(1)
def evaluate(basis, fitparams, x): def evaluate(self, basis, fitparams, x):
# until wrapper for Basis functors are ready, this is how to evaluate a basis function. """
return basis.WeightMatrix(N, x) @ fitparams Until wrapper for Basis functors are ready,
this is how to evaluate a basis function.
"""
return basis.WeightMatrix(self.N, x) @ fitparams
def testBasis(fitter, basis, f=f): def fit_basis_helper(self, fitter, basis, f=lambda x: x):
# test a basis by checking that the fit result matches the function at x-values interpx. """Helper method to fit data to a specified fitter using a specified basis."""
data = {x: f(x) for x in datax} data = {x: f(x) for x in self.x}
fit = fitter(data, noise, N) fit = fitter(data, self.noise, self.N)
coeff = fit.parameters() coeff = fit.parameters()
interpy = evaluate(basis, coeff, interpx) interpy = self.evaluate(basis, coeff, self.interpx)
np.testing.assert_almost_equal(interpy, np.array([f(x) for x in interpx]), decimal=7) return interpy
testBasis(gtsam.FitBasisFourierBasis, gtsam.FourierBasis, f=lambda x: 0.7 * np.cos(x)) def test_fit_basis_fourier(self):
testBasis(gtsam.FitBasisChebyshev1Basis, gtsam.Chebyshev1Basis) """Fit a Fourier basis."""
testBasis(gtsam.FitBasisChebyshev2Basis, gtsam.Chebyshev2Basis)
testBasis(gtsam.FitBasisChebyshev2, gtsam.Chebyshev2) f = lambda x: 0.7 * np.cos(x)
interpy = self.fit_basis_helper(gtsam.FitBasisFourierBasis,
gtsam.FourierBasis, f)
# test a basis by checking that the fit result matches the function at x-values interpx.
np.testing.assert_almost_equal(interpy,
np.array([f(x) for x in self.interpx]),
decimal=7)
def test_fit_basis_chebyshev1basis(self):
"""Fit a Chebyshev1 basis."""
f = lambda x: x
interpy = self.fit_basis_helper(gtsam.FitBasisChebyshev1Basis,
gtsam.Chebyshev1Basis, f)
# test a basis by checking that the fit result matches the function at x-values interpx.
np.testing.assert_almost_equal(interpy,
np.array([f(x) for x in self.interpx]),
decimal=7)
def test_fit_basis_chebyshev2basis(self):
"""Fit a Chebyshev2 basis."""
f = lambda x: x
interpy = self.fit_basis_helper(gtsam.FitBasisChebyshev2Basis,
gtsam.Chebyshev2Basis)
# test a basis by checking that the fit result matches the function at x-values interpx.
np.testing.assert_almost_equal(interpy,
np.array([f(x) for x in self.interpx]),
decimal=7)
def test_fit_basis_chebyshev2(self):
"""Fit a Chebyshev2 pseudospectral basis."""
f = lambda x: x
interpy = self.fit_basis_helper(gtsam.FitBasisChebyshev2,
gtsam.Chebyshev2)
# test a basis by checking that the fit result matches the function at x-values interpx.
np.testing.assert_almost_equal(interpy,
np.array([f(x) for x in self.interpx]),
decimal=7)
if __name__ == "__main__": if __name__ == "__main__":