From a5fd9c120b3159812c196836f82e537cf6b43c07 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 10 Jul 2023 12:54:32 -0400 Subject: [PATCH] fix chi_squared_quantile --- gtsam/nonlinear/GncHelpers.h | 168 +++++++++++------------ gtsam/nonlinear/tests/testGncHelpers.cpp | 37 +++++ 2 files changed, 121 insertions(+), 84 deletions(-) create mode 100644 gtsam/nonlinear/tests/testGncHelpers.cpp diff --git a/gtsam/nonlinear/GncHelpers.h b/gtsam/nonlinear/GncHelpers.h index 2dac9ac5e..7a27a3530 100644 --- a/gtsam/nonlinear/GncHelpers.h +++ b/gtsam/nonlinear/GncHelpers.h @@ -29,9 +29,11 @@ template using return_t = typename std::conditional::value, double, T>::type; +/// Get common type amongst all arguments template using common_t = typename std::common_type::type; +/// Helper template for finding common return type template using common_return_t = return_t>; @@ -126,16 +128,14 @@ template constexpr T incomplete_gamma_quad_recur(const T lb, const T ub, const T a, const T lg_term, const int counter) noexcept { - return (counter < 49 ? // if - incomplete_gamma_quad_fn( - incomplete_gamma_quad_inp_vals(lb, ub, counter), a, lg_term) * - incomplete_gamma_quad_weight_vals(lb, ub, counter) + - incomplete_gamma_quad_recur(lb, ub, a, lg_term, counter + 1) - : - // else - incomplete_gamma_quad_fn( - incomplete_gamma_quad_inp_vals(lb, ub, counter), a, lg_term) * - incomplete_gamma_quad_weight_vals(lb, ub, counter)); + T val = incomplete_gamma_quad_fn( + incomplete_gamma_quad_inp_vals(lb, ub, counter), a, lg_term) * + incomplete_gamma_quad_weight_vals(lb, ub, counter); + if (counter < 49) { + return val + incomplete_gamma_quad_recur(lb, ub, a, lg_term, counter + 1); + } else { + return val; + } } template @@ -180,10 +180,13 @@ constexpr T incomplete_gamma_quad(const T a, const T z) noexcept { template constexpr T incomplete_gamma_cf_2_recur(const T a, const T z, const int depth) noexcept { - return (depth < 100 ? (1 + (depth - 1) * 2 - a + z) + - depth * (a - depth) / - incomplete_gamma_cf_2_recur(a, z, depth + 1) - : (1 + (depth - 1) * 2 - a + z)); + T val = 1 + (depth - 1) * 2 - a + z; + if (depth < 100) { + return val + + depth * (a - depth) / incomplete_gamma_cf_2_recur(a, z, depth + 1); + } else { + return val; + } } template @@ -200,50 +203,49 @@ constexpr T incomplete_gamma_cf_2( template constexpr T incomplete_gamma_cf_1_coef(const T a, const T z, const int depth) noexcept { - return (is_odd(depth) ? -(a - 1 + T(depth + 1) / T(2)) * z - : T(depth) / T(2) * z); + if (is_odd(depth)) { + return -(a - 1 + T(depth + 1) / T(2)) * z; + } else { + return T(depth) / T(2) * z; + } } template constexpr T incomplete_gamma_cf_1_recur(const T a, const T z, const int depth) noexcept { - return (depth < 55 ? // if - (a + depth - 1) + incomplete_gamma_cf_1_coef(a, z, depth) / - incomplete_gamma_cf_1_recur(a, z, depth + 1) - : - // else - (a + depth - 1)); + T val = a + depth - 1; + if (depth < 55) { + return val + incomplete_gamma_cf_1_coef(a, z, depth) / + incomplete_gamma_cf_1_recur(a, z, depth + 1); + } else { + return val; + } } +/// lower (regularized) incomplete gamma function template -constexpr T incomplete_gamma_cf_1( - const T a, - const T z) noexcept { // lower (regularized) incomplete gamma function - return (exp(a * log(z) - z - lgamma(a)) / - incomplete_gamma_cf_1_recur(a, z, 1)); +constexpr T incomplete_gamma_cf_1(const T a, const T z) noexcept { + return exp(a * log(z) - z - lgamma(a)) / incomplete_gamma_cf_1_recur(a, z, 1); } -// - +/// Perform NaN check template constexpr T incomplete_gamma_check(const T a, const T z) noexcept { - return ( // NaN check - (is_nan(a) || is_nan(z)) ? LIM::quiet_NaN() : - // - a < T(0) ? LIM::quiet_NaN() - : - // - LIM::min() > z ? T(0) - : - // - LIM::min() > a ? T(1) - : - // cf or quadrature - (a < T(10)) && (z - a < T(10)) ? incomplete_gamma_cf_1(a, z) - : (a < T(10)) || (z / a > T(3)) ? incomplete_gamma_cf_2(a, z) - : - // else - incomplete_gamma_quad(a, z)); + if (is_nan(a) || is_nan(z)) { + return LIM::quiet_NaN(); + } else if (a < T(0)) { + return LIM::quiet_NaN(); + } else if (LIM::min() > z) { + return T(0); + } else if (LIM::min() > a) { + return T(1); + } else if (a < T(10) && z - a < T(10)) { // cf or quadrature + return incomplete_gamma_cf_1(a, z); + } else if (a < T(10) || z / a > T(3)) { + return incomplete_gamma_cf_2(a, z); + } else { + return incomplete_gamma_quad(a, z); + } } template > @@ -323,24 +325,24 @@ constexpr T incomplete_gamma_inv_initial_val_1( template constexpr T incomplete_gamma_inv_initial_val_2( const T a, const T p, const T t_val) noexcept { // a <= 1.0 - return (p < t_val ? // if - pow(p / t_val, T(1) / a) - : - // else - T(1) - log(T(1) - (p - t_val) / (T(1) - t_val))); + if (p < t_val) { + return pow(p / t_val, T(1) / a); + } else { + return T(1) - log(T(1) - (p - t_val) / (T(1) - t_val)); + } } -// initial value +// Initial value template constexpr T incomplete_gamma_inv_initial_val(const T a, const T p) noexcept { - return (a > T(1) ? // if - incomplete_gamma_inv_initial_val_1( - a, incomplete_gamma_inv_t_val_1(p), p > T(0.5) ? T(-1) : T(1)) - : - // else - incomplete_gamma_inv_initial_val_2( - a, p, incomplete_gamma_inv_t_val_2(a))); + if (a > T(1)) { + return incomplete_gamma_inv_initial_val_1( + a, incomplete_gamma_inv_t_val_1(p), p > T(0.5) ? T(-1) : T(1)); + } else { + return incomplete_gamma_inv_initial_val_2(a, p, + incomplete_gamma_inv_t_val_2(a)); + } } // @@ -405,18 +407,15 @@ template constexpr T incomplete_gamma_inv_decision(const T value, const T a, const T p, const T direc, const T lg_val, const int iter_count) noexcept { -// return( abs(direc) > GCEM_INCML_GAMMA_INV_TOL ? -// incomplete_gamma_inv_recur(value - direc, a, p, -// incomplete_gamma_inv_deriv_1(value,a,lg_val), lg_val) : value - direc ); -#define INCML_GAMMA_INV_MAX_ITER 35 - return (iter_count <= INCML_GAMMA_INV_MAX_ITER ? // if - incomplete_gamma_inv_recur( - value - direc, a, p, - incomplete_gamma_inv_deriv_1(value, a, lg_val), lg_val, - iter_count + 1) - : - // else - value - direc); + constexpr int INCML_GAMMA_INV_MAX_ITER = 35; + + if (iter_count <= INCML_GAMMA_INV_MAX_ITER) { + return incomplete_gamma_inv_recur( + value - direc, a, p, incomplete_gamma_inv_deriv_1(value, a, lg_val), + lg_val, iter_count + 1); + } else { + return value - direc; + } } template @@ -429,19 +428,20 @@ constexpr T incomplete_gamma_inv_begin(const T initial_val, const T a, template constexpr T incomplete_gamma_inv_check(const T a, const T p) noexcept { - return ( // NaN check - (is_nan(a) || is_nan(p)) ? LIM::quiet_NaN() : - // - LIM::min() > p ? T(0) - : p > T(1) ? LIM::quiet_NaN() - : LIM::min() > abs(T(1) - p) ? LIM::infinity() - : - // - LIM::min() > a ? T(0) - : - // else - incomplete_gamma_inv_begin(incomplete_gamma_inv_initial_val(a, p), a, - p, lgamma(a))); + if (is_nan(a) || is_nan(p)) { + return LIM::quiet_NaN(); + } else if (LIM::min() > p) { + return T(0); + } else if (p > T(1)) { + return LIM::quiet_NaN(); + } else if (LIM::min() > fabs(T(1) - p)) { + return LIM::infinity(); + } else if (LIM::min() > a) { + return T(0); + } else { + return incomplete_gamma_inv_begin(incomplete_gamma_inv_initial_val(a, p), a, + p, lgamma(a)); + } } template > diff --git a/gtsam/nonlinear/tests/testGncHelpers.cpp b/gtsam/nonlinear/tests/testGncHelpers.cpp new file mode 100644 index 000000000..6e47f97cc --- /dev/null +++ b/gtsam/nonlinear/tests/testGncHelpers.cpp @@ -0,0 +1,37 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file testGncHelpers.cpp + * @date July 10, 2023 + * @author Varun Agrawal + * @brief Tests for Chi-squared distribution. + */ + +#include +#include +#include + +using namespace gtsam; + +/* ************************************************************************* */ +TEST(GncHelpers, ChiSqInv) { + double barcSq = chi_squared_quantile(2, 0.99); + EXPECT_DOUBLES_EQUAL(9.21034, barcSq, 1e-5); +} + +/* ************************************************************************* */ +int main() { + srand(time(nullptr)); + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */