From 8201c77b30b3d69d0a9b775b8eb164be9650ae3c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 10 May 2023 15:37:46 -0400 Subject: [PATCH] refactor IncompleteGamma class --- gtsam/nonlinear/GncHelpers.h | 267 +++++++++++++++++------------------ 1 file changed, 131 insertions(+), 136 deletions(-) diff --git a/gtsam/nonlinear/GncHelpers.h b/gtsam/nonlinear/GncHelpers.h index 399da2c99..38185249c 100644 --- a/gtsam/nonlinear/GncHelpers.h +++ b/gtsam/nonlinear/GncHelpers.h @@ -103,154 +103,147 @@ static const long double gauss_legendre_50_weights[50] = { namespace internal { -/// 50 point Gauss-Legendre quadrature template -constexpr T incomplete_gamma_quad_inp_vals(const T lb, const T ub, - const int counter) noexcept { - return (ub - lb) * gauss_legendre_50_points[counter] / T(2) + - (ub + lb) / T(2); -} +class IncompleteGamma { + /// 50 point Gauss-Legendre quadrature + static constexpr T quadrature_inp_vals(const T lb, const T ub, + const int counter) noexcept { + return (ub - lb) * gauss_legendre_50_points[counter] / T(2) + + (ub + lb) / T(2); + } -template -constexpr T incomplete_gamma_quad_weight_vals(const T lb, const T ub, - const int counter) noexcept { - return (ub - lb) * gauss_legendre_50_weights[counter] / T(2); -} + static constexpr T quadrature_weight_vals(const T lb, const T ub, + const int counter) noexcept { + return (ub - lb) * gauss_legendre_50_weights[counter] / T(2); + } -template -constexpr T incomplete_gamma_quad_fn(const T x, const T a, - const T lg_term) noexcept { - return exp(-x + (a - T(1)) * log(x) - lg_term); -} + static constexpr T quadrature_fn(const T x, const T a, + const T lg_term) noexcept { + return exp(-x + (a - T(1)) * log(x) - lg_term); + } -template -constexpr T incomplete_gamma_quad_recur(const T lb, const T ub, const T a, - const T lg_term, - const int counter) noexcept { - return (counter < 49 ? // if - incomplete_gamma_quad_fn( - incomplete_gamma_quad_inp_vals(lb, ub, counter), a, lg_term) * - incomplete_gamma_quad_weight_vals(lb, ub, counter) + - incomplete_gamma_quad_recur(lb, ub, a, lg_term, counter + 1) - : - // else - incomplete_gamma_quad_fn( - incomplete_gamma_quad_inp_vals(lb, ub, counter), a, lg_term) * - incomplete_gamma_quad_weight_vals(lb, ub, counter)); -} + static constexpr T quadrature_recur(const T lb, const T ub, const T a, + const T lg_term, + const int counter) noexcept { + if (counter < 49) { + return quadrature_fn(quadrature_inp_vals(lb, ub, counter), a, lg_term) * + quadrature_weight_vals(lb, ub, counter) + + quadrature_recur(lb, ub, a, lg_term, counter + 1); + } else { + return quadrature_fn(quadrature_inp_vals(lb, ub, counter), a, lg_term) * + quadrature_weight_vals(lb, ub, counter); + } + } -template -constexpr T incomplete_gamma_quad_lb(const T a, const T z) noexcept { - // break integration into ranges - return (a > T(1000) ? std::max(T(0), std::min(z, a) - 11 * sqrt(a)) - : a > T(800) ? std::max(T(0), std::min(z, a) - 11 * sqrt(a)) - : a > T(500) ? std::max(T(0), std::min(z, a) - 10 * sqrt(a)) - : a > T(300) ? std::max(T(0), std::min(z, a) - 10 * sqrt(a)) - : a > T(100) ? std::max(T(0), std::min(z, a) - 9 * sqrt(a)) - : a > T(90) ? std::max(T(0), std::min(z, a) - 9 * sqrt(a)) - : a > T(70) ? std::max(T(0), std::min(z, a) - 8 * sqrt(a)) - : a > T(50) ? std::max(T(0), std::min(z, a) - 7 * sqrt(a)) - : a > T(40) ? std::max(T(0), std::min(z, a) - 6 * sqrt(a)) - : a > T(30) ? std::max(T(0), std::min(z, a) - 5 * sqrt(a)) - : std::max(T(0), std::min(z, a) - 4 * sqrt(a))); -} + static constexpr T quadrature_lb(const T a, const T z) noexcept { + // break integration into ranges + return a > T(1000) ? std::max(T(0), std::min(z, a) - 11 * sqrt(a)) + : a > T(800) ? std::max(T(0), std::min(z, a) - 11 * sqrt(a)) + : a > T(500) ? std::max(T(0), std::min(z, a) - 10 * sqrt(a)) + : a > T(300) ? std::max(T(0), std::min(z, a) - 10 * sqrt(a)) + : a > T(100) ? std::max(T(0), std::min(z, a) - 9 * sqrt(a)) + : a > T(90) ? std::max(T(0), std::min(z, a) - 9 * sqrt(a)) + : a > T(70) ? std::max(T(0), std::min(z, a) - 8 * sqrt(a)) + : a > T(50) ? std::max(T(0), std::min(z, a) - 7 * sqrt(a)) + : a > T(40) ? std::max(T(0), std::min(z, a) - 6 * sqrt(a)) + : a > T(30) ? std::max(T(0), std::min(z, a) - 5 * sqrt(a)) + : std::max(T(0), std::min(z, a) - 4 * sqrt(a)); + } -template -constexpr T incomplete_gamma_quad_ub(const T a, const T z) noexcept { - return (a > T(1000) ? std::min(z, a + 10 * sqrt(a)) - : a > T(800) ? std::min(z, a + 10 * sqrt(a)) - : a > T(500) ? std::min(z, a + 9 * sqrt(a)) - : a > T(300) ? std::min(z, a + 9 * sqrt(a)) - : a > T(100) ? std::min(z, a + 8 * sqrt(a)) - : a > T(90) ? std::min(z, a + 8 * sqrt(a)) - : a > T(70) ? std::min(z, a + 7 * sqrt(a)) - : a > T(50) ? std::min(z, a + 6 * sqrt(a)) - : std::min(z, a + 5 * sqrt(a))); -} + static constexpr T quadrature_ub(const T a, const T z) noexcept { + return a > T(1000) ? std::min(z, a + 10 * sqrt(a)) + : a > T(800) ? std::min(z, a + 10 * sqrt(a)) + : a > T(500) ? std::min(z, a + 9 * sqrt(a)) + : a > T(300) ? std::min(z, a + 9 * sqrt(a)) + : a > T(100) ? std::min(z, a + 8 * sqrt(a)) + : a > T(90) ? std::min(z, a + 8 * sqrt(a)) + : a > T(70) ? std::min(z, a + 7 * sqrt(a)) + : a > T(50) ? std::min(z, a + 6 * sqrt(a)) + : std::min(z, a + 5 * sqrt(a)); + } -template -constexpr T incomplete_gamma_quad(const T a, const T z) noexcept { - return incomplete_gamma_quad_recur(incomplete_gamma_quad_lb(a, z), - incomplete_gamma_quad_ub(a, z), a, - lgamma(a), 0); -} + static constexpr T quadrature(const T a, const T z) noexcept { + return quadrature_recur(quadrature_lb(a, z), quadrature_ub(a, z), a, + lgamma(a), 0); + } -// reverse cf expansion -// see: https://functions.wolfram.com/GammaBetaErf/Gamma2/10/0003/ + // reverse cf expansion + // see: https://functions.wolfram.com/GammaBetaErf/Gamma2/10/0003/ + static constexpr T cf_2_recur(const T a, const T z, + const int depth) noexcept { + if (depth < 100) { + return (1 + (depth - 1) * 2 - a + z) + + depth * (a - depth) / cf_2_recur(a, z, depth + 1); + } else { + return 1 + (depth - 1) * 2 - a + z; + } + } -template -constexpr T incomplete_gamma_cf_2_recur(const T a, const T z, - const int depth) noexcept { - return (depth < 100 ? (1 + (depth - 1) * 2 - a + z) + - depth * (a - depth) / - incomplete_gamma_cf_2_recur(a, z, depth + 1) - : (1 + (depth - 1) * 2 - a + z)); -} + /** + * @brief Lower (regularized) incomplete gamma function + * + * @param a + * @param z + * @return constexpr T + */ + static constexpr T cf_2(const T a, const T z) noexcept { + return T(1.0) - exp(a * log(z) - z - lgamma(a)) / cf_2_recur(a, z, 1); + } -template -constexpr T incomplete_gamma_cf_2( - const T a, - const T z) noexcept { // lower (regularized) incomplete gamma function - return (T(1.0) - exp(a * log(z) - z - lgamma(a)) / - incomplete_gamma_cf_2_recur(a, z, 1)); -} + // continued fraction expansion + // see: http://functions.wolfram.com/GammaBetaErf/Gamma2/10/0009/ + static constexpr T cf_1_coef(const T a, const T z, const int depth) noexcept { + return (is_odd(depth) ? -(a - 1 + T(depth + 1) / T(2)) * z + : T(depth) / T(2) * z); + } -// cf expansion -// see: http://functions.wolfram.com/GammaBetaErf/Gamma2/10/0009/ + static constexpr T cf_1_recur(const T a, const T z, + const int depth) noexcept { + if (depth < 55) { + return (a + depth - 1) + + cf_1_coef(a, z, depth) / cf_1_recur(a, z, depth + 1); + } else { + return (a + depth - 1); + } + } -template -constexpr T incomplete_gamma_cf_1_coef(const T a, const T z, - const int depth) noexcept { - return (is_odd(depth) ? -(a - 1 + T(depth + 1) / T(2)) * z - : T(depth) / T(2) * z); -} + /** + * @brief Lower (regularized) incomplete gamma function + * + * @param a + * @param z + * @return constexpr T + */ + static constexpr T cf_1(const T a, const T z) noexcept { + return exp(a * log(z) - z - lgamma(a)) / cf_1_recur(a, z, 1); + } -template -constexpr T incomplete_gamma_cf_1_recur(const T a, const T z, - const int depth) noexcept { - return (depth < 55 ? // if - (a + depth - 1) + incomplete_gamma_cf_1_coef(a, z, depth) / - incomplete_gamma_cf_1_recur(a, z, depth + 1) - : - // else - (a + depth - 1)); -} - -template -constexpr T incomplete_gamma_cf_1( - const T a, - const T z) noexcept { // lower (regularized) incomplete gamma function - return (exp(a * log(z) - z - lgamma(a)) / - incomplete_gamma_cf_1_recur(a, z, 1)); -} - -// - -template -constexpr T incomplete_gamma_check(const T a, const T z) noexcept { - return ( // NaN check - (is_nan(a) || is_nan(z)) ? LIM::quiet_NaN() : - // - a < T(0) ? LIM::quiet_NaN() - : - // - LIM::min() > z ? T(0) - : - // - LIM::min() > a ? T(1) - : - // cf or quadrature - (a < T(10)) && (z - a < T(10)) ? incomplete_gamma_cf_1(a, z) - : (a < T(10)) || (z / a > T(3)) ? incomplete_gamma_cf_2(a, z) - : - // else - incomplete_gamma_quad(a, z)); -} - -template > -constexpr TC incomplete_gamma_type_check(const T1 a, const T2 p) noexcept { - return incomplete_gamma_check(static_cast(a), static_cast(p)); -} + public: + /** + * @brief Compute the CDF for the Gamma distribution. + * + * @param a + * @param z + * @return constexpr T + */ + static constexpr T compute(const T a, const T z) noexcept { + if (is_nan(a) || is_nan(z)) { // NaN check + return LIM::quiet_NaN(); + } else if (a < T(0)) { + return LIM::quiet_NaN(); + } else if (LIM::min() > z) { + return T(0); + } else if (LIM::min() > a) { + return T(1); + } else if (a < T(10) && z - a < T(10)) { + return cf_1(a, z); + } else if (a < T(10) || z / a > T(3)) { + return cf_2(a, z); + } else { + return quadrature(a, z); + } + } +}; } // namespace internal @@ -274,7 +267,9 @@ constexpr TC incomplete_gamma_type_check(const T1 a, const T2 p) noexcept { template constexpr common_return_t incomplete_gamma(const T1 a, const T2 x) noexcept { - return internal::incomplete_gamma_type_check(a, x); + using TC = common_return_t; + return internal::IncompleteGamma::compute(static_cast(a), + static_cast(x)); } namespace internal {