diff --git a/python/gtsam/tests/test_basis.py b/python/gtsam/tests/test_basis.py index 3af0a87f1..5d3c5ace3 100644 --- a/python/gtsam/tests/test_basis.py +++ b/python/gtsam/tests/test_basis.py @@ -19,39 +19,77 @@ from gtsam.symbol_shorthand import B class TestBasis(GtsamTestCase): """ - Tests Basis module python bindings. + Tests FitBasis python binding for FourierBasis, Chebyshev1Basis, Chebyshev2Basis, and Chebyshev2. + + It tests FitBasis by fitting to a ground-truth function that can be represented exactly in + the basis, then checking that the regression (fit result) matches the function. For the + Chebyshev bases, the line y=x is used to generate the data while for Fourier, 0.7*cos(x) is + used. """ - def test_fit_basis(self): + def setUp(self): + self.N = 2 + self.x = [0., 0.5, 0.75] + self.interpx = np.linspace(0., 1., 10) + self.noise = gtsam.noiseModel.Unit.Create(1) + + def evaluate(self, basis, fitparams, x): """ - Tests FitBasis python binding for FourierBasis, Chebyshev1Basis, Chebyshev2Basis, and - Chebyshev2. - It tests FitBasis by fitting to a ground-truth function that can be represented exactly in - the basis, then checking that the regression (fit result) matches the function. For the - Chebyshev bases, the line y=x is used to generate the data while for Fourier, 0.7*cos(x) is - used. + Until wrapper for Basis functors are ready, + this is how to evaluate a basis function. """ - f = lambda x: x # line y = x - N = 2 - datax = [0., 0.5, 0.75] - interpx = np.linspace(0., 1., 10) - noise = gtsam.noiseModel.Unit.Create(1) + return basis.WeightMatrix(self.N, x) @ fitparams - def evaluate(basis, fitparams, x): - # until wrapper for Basis functors are ready, this is how to evaluate a basis function. - return basis.WeightMatrix(N, x) @ fitparams + def fit_basis_helper(self, fitter, basis, f=lambda x: x): + """Helper method to fit data to a specified fitter using a specified basis.""" + data = {x: f(x) for x in self.x} + fit = fitter(data, self.noise, self.N) + coeff = fit.parameters() + interpy = self.evaluate(basis, coeff, self.interpx) + return interpy - def testBasis(fitter, basis, f=f): - # test a basis by checking that the fit result matches the function at x-values interpx. - data = {x: f(x) for x in datax} - fit = fitter(data, noise, N) - coeff = fit.parameters() - interpy = evaluate(basis, coeff, interpx) - np.testing.assert_almost_equal(interpy, np.array([f(x) for x in interpx]), decimal=7) + def test_fit_basis_fourier(self): + """Fit a Fourier basis.""" - testBasis(gtsam.FitBasisFourierBasis, gtsam.FourierBasis, f=lambda x: 0.7 * np.cos(x)) - testBasis(gtsam.FitBasisChebyshev1Basis, gtsam.Chebyshev1Basis) - testBasis(gtsam.FitBasisChebyshev2Basis, gtsam.Chebyshev2Basis) - testBasis(gtsam.FitBasisChebyshev2, gtsam.Chebyshev2) + f = lambda x: 0.7 * np.cos(x) + interpy = self.fit_basis_helper(gtsam.FitBasisFourierBasis, + gtsam.FourierBasis, f) + # test a basis by checking that the fit result matches the function at x-values interpx. + np.testing.assert_almost_equal(interpy, + np.array([f(x) for x in self.interpx]), + decimal=7) + + def test_fit_basis_chebyshev1basis(self): + """Fit a Chebyshev1 basis.""" + + f = lambda x: x + interpy = self.fit_basis_helper(gtsam.FitBasisChebyshev1Basis, + gtsam.Chebyshev1Basis, f) + # test a basis by checking that the fit result matches the function at x-values interpx. + np.testing.assert_almost_equal(interpy, + np.array([f(x) for x in self.interpx]), + decimal=7) + + def test_fit_basis_chebyshev2basis(self): + """Fit a Chebyshev2 basis.""" + + f = lambda x: x + interpy = self.fit_basis_helper(gtsam.FitBasisChebyshev2Basis, + gtsam.Chebyshev2Basis) + # test a basis by checking that the fit result matches the function at x-values interpx. + np.testing.assert_almost_equal(interpy, + np.array([f(x) for x in self.interpx]), + decimal=7) + + def test_fit_basis_chebyshev2(self): + """Fit a Chebyshev2 pseudospectral basis.""" + + f = lambda x: x + interpy = self.fit_basis_helper(gtsam.FitBasisChebyshev2, + gtsam.Chebyshev2) + # test a basis by checking that the fit result matches the function at x-values interpx. + np.testing.assert_almost_equal(interpy, + np.array([f(x) for x in self.interpx]), + decimal=7) if __name__ == "__main__":