// Provenance (kept verbatim from the patch this code was extracted from; the
// include targets were already missing in the extracted text):
//
//   From bebb275489ae6030353148eef881fba9c24ca9b7 Mon Sep 17 00:00:00 2001
//   From: Varun Agrawal
//   Date: Fri, 20 Oct 2023 10:21:49 -0400
//   Subject: [PATCH] compute initial guess for inverse gamma value
//   ---
//   gtsam/nonlinear/internal/chiSquaredInverse.h | 304 +++++++++++++++++--
//   1 file changed, 275 insertions(+), 29 deletions(-)
//   diff --git a/gtsam/nonlinear/internal/chiSquaredInverse.h
//              b/gtsam/nonlinear/internal/chiSquaredInverse.h
//   index 528694284..dc8595846 100644
//   --- a/gtsam/nonlinear/internal/chiSquaredInverse.h
//   +++ b/gtsam/nonlinear/internal/chiSquaredInverse.h
//   @@ -24,9 +24,9 @@
//    #pragma once
//   -#include
//   -#include
//   -#include
//   +#include
//   +#include
//   +#include
//    #include
//   @@ -37,6 +37,270 @@ namespace gtsam { namespace internal {

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>

namespace gtsam {
namespace internal {

/**
 * @brief Polynomial evaluation with runtime size (Horner's scheme).
 *
 * `poly[i]` is the coefficient of `z^i`; the highest-order coefficient is
 * consumed first and the running sum is multiplied by `z` at each step.
 *
 * @tparam T Coefficient type.
 * @tparam U Type of the evaluation point and of the result.
 * @param poly Pointer to `count` coefficients, lowest order first.
 * @param z Point at which to evaluate the polynomial.
 * @param count Number of coefficients; must be greater than zero.
 * @return U The value of the polynomial at `z`.
 */
template <class T, class U>
inline U evaluate_polynomial(const T* poly, U const& z, std::size_t count) {
  assert(count > 0);
  U sum = static_cast<U>(poly[count - 1]);
  // Count down from index count-2 to 0 without leaving the unsigned domain.
  for (std::size_t i = count - 1; i-- > 0;) {
    sum *= z;
    sum += static_cast<U>(poly[i]);
  }
  return sum;
}

/**
 * @brief Computation of the Incomplete Gamma Function Ratios and their Inverse.
 *
 * Reference:
 * ARMIDO R. DIDONATO and ALFRED H. MORRIS, JR.
 * ACM Transactions on Mathematical Software, Vol. 12, No. 4,
 * December 1986, Pages 377-393.
 *
 * See equation 32.
 *
 * Evaluates `t - A(t) / B(t)` with `t = sqrt(-2 * log(min-side probability))`,
 * where the smaller tail is selected by `p < 0.5`, and negates the result
 * when `p < 0.5`.
 *
 * @tparam T Floating point type.
 * @param p Lower-tail probability.
 * @param q Upper-tail probability (the caller passes `1 - p`).
 * @return T The value `s` used by equation 31.
 */
template <class T>
T find_inverse_s(T p, T q) {
  const bool lower_tail = p < T(0.5);
  const T t = std::sqrt(-2 * std::log(lower_tail ? p : q));

  static const double a[4] = {3.31125922108741, 11.6616720288968,
                              4.28342155967104, 0.213623493715853};
  static const double b[5] = {1, 6.61053765625462, 6.40691597760039,
                              1.27364489782223, 0.3611708101884203e-1};

  T s = t - evaluate_polynomial(a, t, 4) / evaluate_polynomial(b, t, 5);
  if (lower_tail) s = -s;
  return s;
}

/**
 * @brief Computation of the Incomplete Gamma Function Ratios and their Inverse.
 *
 * Reference:
 * ARMIDO R. DIDONATO and ALFRED H. MORRIS, JR.
 * ACM Transactions on Mathematical Software, Vol. 12, No. 4,
 * December 1986, Pages 377-393.
 *
 * See equation 34.
 *
 * Accumulates `1 + x/(a+1) + x^2/((a+1)(a+2)) + ...` using at most `N` terms
 * after the leading 1.
 *
 * @tparam T Floating point type.
 * @param a Shape parameter.
 * @param x Series argument.
 * @param N Maximum number of terms to add after the leading 1.
 * @param tolerance Terms from the second one onwards stop the summation once
 * they fall below this value (the first term is always added).
 * @return T The partial sum.
 */
template <class T>
T didonato_SN(T a, T x, unsigned N, T tolerance = 0) {
  T sum = 1;
  if (N >= 1) {
    T partial = x / (a + 1);
    sum += partial;
    for (unsigned i = 2; i <= N; ++i) {
      partial *= x / (a + i);
      sum += partial;
      if (partial < tolerance) break;
    }
  }
  return sum;
}

/**
 * @brief DiDonato & Morris Eq 25: series in inverse powers of `y`.
 *
 * Shared by the two branches of find_inverse_gamma that use equation 25, so
 * both evaluate exactly the same expression.
 *
 * @tparam T Floating point type.
 * @param a Shape parameter.
 * @param y Expansion variable (`-log(q * Gamma(a))` at both call sites).
 * @return T The value of the five-term expansion.
 */
template <class T>
T didonato_eq25(T a, T y) {
  const T c1 = (a - 1) * std::log(y);
  const T c1_2 = c1 * c1;
  const T c1_3 = c1_2 * c1;
  const T c1_4 = c1_2 * c1_2;
  const T a_2 = a * a;
  const T a_3 = a_2 * a;

  const T c2 = (a - 1) * (1 + c1);
  const T c3 = (a - 1) * (-(c1_2 / 2) + (a - 2) * c1 + (3 * a - 5) / 2);
  const T c4 = (a - 1) * ((c1_3 / 3) - (3 * a - 5) * c1_2 / 2 +
                          (a_2 - 6 * a + 7) * c1 +
                          (11 * a_2 - 46 * a + 47) / 6);
  const T c5 = (a - 1) * (-(c1_4 / 4) + (11 * a - 17) * c1_3 / 6 +
                          (-3 * a_2 + 13 * a - 13) * c1_2 +
                          (2 * a_3 - 25 * a_2 + 72 * a - 61) * c1 / 2 +
                          (25 * a_3 - 195 * a_2 + 477 * a - 379) / 12);

  const T y_2 = y * y;
  const T y_3 = y_2 * y;
  const T y_4 = y_2 * y_2;
  return y + c1 + (c2 / y) + (c3 / y_2) + (c4 / y_3) + (c5 / y_4);
}

/**
 * @brief Compute the initial inverse gamma value guess.
 *
 * We use the implementation in this paper:
 * Computation of the Incomplete Gamma Function Ratios and their Inverse
 * ARMIDO R. DIDONATO and ALFRED H. MORRIS, JR.
 * ACM Transactions on Mathematical Software, Vol. 12, No. 4,
 * December 1986, Pages 377-393.
 *
 * The formula is selected from `a` and from `b = q * Gamma(a)` (for `a < 1`)
 * or from the equation 31 estimate `w` (for `a > 1`).
 *
 * @tparam T Floating point type.
 * @param a Shape parameter.
 * @param p Lower-tail probability.
 * @param q Upper-tail probability (the caller passes `1 - p`).
 * @param p_has_10_digits Output flag, must not be null. Reset to false on
 * entry and set to true only on the branches marked below; judging by its
 * name it signals that the guess is already good to about 10 digits
 * (NOTE(review): confirm against the consumer of this flag).
 * @return T The initial guess.
 */
template <class T>
T find_inverse_gamma(T a, T p, T q, bool* p_has_10_digits) {
  assert(p_has_10_digits != nullptr);
  T result;
  *p_has_10_digits = false;

  // TODO(Varun) replace with egamma_v in C++20
  // Euler-Mascheroni constant
  constexpr double euler =
      0.577215664901532860606512090082402431042159335939923598805;

  if (a == 1) {
    result = -std::log(q);
  } else if (a < 1) {
    const T g = std::tgamma(a);
    const T b = q * g;

    if ((b > T(0.6)) || ((b >= T(0.45)) && (a >= T(0.3)))) {
      // DiDonato & Morris Eq 21:
      //
      // There is a slight variation from DiDonato and Morris here:
      // the first form given here is unstable when p is close to 1,
      // making it impossible to compute the inverse of Q(a,x) for small
      // q. Fortunately the second form works perfectly well in this case.
      T u;
      if ((b * q > T(1e-8)) && (q > T(1e-5))) {
        u = std::pow(p * g * a, 1 / a);
      } else {
        u = std::exp((-q / a) - euler);
      }
      result = u / (1 - (u / (a + 1)));

    } else if ((a < 0.3) && (b >= 0.35)) {
      // DiDonato & Morris Eq 22:
      const T t = std::exp(-euler - b);
      const T u = t * std::exp(t);
      result = t * std::exp(u);

    } else if ((b > 0.15) || (a >= 0.3)) {
      // DiDonato & Morris Eq 23:
      const T y = -std::log(b);
      const T u = y - (1 - a) * std::log(y);
      result = y - (1 - a) * std::log(u) - std::log(1 + (1 - a) / (1 + u));

    } else if (b > 0.1) {
      // DiDonato & Morris Eq 24:
      const T y = -std::log(b);
      const T u = y - (1 - a) * std::log(y);
      result = y - (1 - a) * std::log(u) -
               std::log((u * u + 2 * (3 - a) * u + (2 - a) * (3 - a)) /
                        (u * u + (5 - a) * u + 2));

    } else {
      // DiDonato & Morris Eq 25:
      result = didonato_eq25(a, T(-std::log(b)));

      if (b < 1e-28f) *p_has_10_digits = true;
    }
  } else {
    // DiDonato and Morris Eq 31:
    T s = find_inverse_s(p, q);

    const T s_2 = s * s;
    const T s_3 = s_2 * s;
    const T s_4 = s_2 * s_2;
    const T s_5 = s_4 * s;
    const T ra = std::sqrt(a);

    T w = a + s * ra + (s * s - 1) / 3;
    w += (s_3 - 7 * s) / (36 * ra);
    w -= (3 * s_4 + 7 * s_2 - 16) / (810 * a);
    w += (9 * s_5 + 256 * s_3 - 433 * s) / (38880 * a * ra);

    if ((a >= 500) && (std::fabs(1 - w / a) < 1e-6)) {
      result = w;
      *p_has_10_digits = true;

    } else if (p > 0.5) {
      if (w < 3 * a) {
        result = w;

      } else {
        const T D = (std::max)(T(2), T(a * (a - 1)));
        const T lg = std::lgamma(a);
        const T lb = std::log(q) + lg;
        if (lb < -D * T(2.3)) {
          // DiDonato and Morris Eq 25:
          result = didonato_eq25(a, T(-lb));

        } else {
          // DiDonato and Morris Eq 33:
          const T u =
              -lb + (a - 1) * std::log(w) - std::log(1 + (1 - a) / (1 + w));
          result =
              -lb + (a - 1) * std::log(u) - std::log(1 + (1 - a) / (1 + u));
        }
      }
    } else {
      T z = w;
      const T ap1 = a + 1;
      const T ap2 = a + 2;
      if (w < 0.15f * ap1) {
        // DiDonato and Morris Eq 35:
        // Three fixed-point refinements of z; `s` is reused as scratch for
        // the log1p correction term, as in the original code.
        const T v = std::log(p) + std::lgamma(ap1);
        z = std::exp((v + w) / a);
        s = std::log1p(z / ap1 * (1 + z / ap2));
        z = std::exp((v + z - s) / a);
        s = std::log1p(z / ap1 * (1 + z / ap2));
        z = std::exp((v + z - s) / a);
        s = std::log1p(z / ap1 * (1 + z / ap2 * (1 + z / (a + 3))));
        z = std::exp((v + z - s) / a);
      }

      if ((z <= 0.01 * ap1) || (z > 0.7 * ap1)) {
        result = z;
        if (z <= T(0.002) * ap1) *p_has_10_digits = true;

      } else {
        // DiDonato and Morris Eq 36:
        const T ls = std::log(didonato_SN(a, z, 100, T(1e-4)));
        const T v = std::log(p) + std::lgamma(ap1);
        z = std::exp((v + z - ls) / a);
        result = z * (1 - (a * std::log(z) - z - v + ls) / (a - z));
      }
    }
  }
  return result;
}

}  // namespace internal
}  // namespace gtsam

// The remainder of the original patch modifies gamma_p_inv_imp and is
// truncated in this view, so it is left untouched. Its visible beginning is
// kept verbatim here; the text continues on the following line of the file.
//
//    template T gamma_p_inv_imp(const T a, const T p) { if (is_nan(a) || is_nan(p)) {
//   @@ -53,13 +317,9 @@ T gamma_p_inv_imp(const T a, const T p) { return 0; }
//   - // TODO
//     // Get an initial guess
(https://dl.acm.org/doi/abs/10.1145/22721.23109) - // T guess = find_inverse_gamma(a, p, 1 - p); bool has_10_digits = false; - boost::math::policies::policy<> pol; - T guess = boost::math::detail::find_inverse_gamma(a, p, 1 - p, pol, - &has_10_digits); + T guess = find_inverse_gamma(a, p, 1 - p, &has_10_digits); T lower = LIM::min(); if (guess <= lower) { @@ -67,35 +327,21 @@ T gamma_p_inv_imp(const T a, const T p) { } // TODO + // The number of digits to converge to. + // This is an arbitrary but reasonable number, + // though Boost does more sophisticated things + // using the first derivative. + unsigned digits = 25; + // Number of Halley iterations // The default used in Boost is 200 // uint_fast16_t max_iter = 200; - // The number of digits to converge to. - // This is an arbitrary number, - // but Boost does more sophisticated things - // using the first derivative. - // unsigned digits = 40; - // // Perform Halley iteration for root-finding to get a more refined answer // guess = halley_iterate(gamma_p_inverse_func(a, p, false), guess, lower, // LIM::max(), digits, max_iter); - unsigned digits = - boost::math::policies::digits>(); - if (digits < 30) { - digits *= 2; - digits /= 3; - } else { - digits /= 2; - digits -= 1; - } - if ((a < T(0.125)) && (fabs(boost::math::gamma_p_derivative(a, guess, pol)) > - 1 / sqrt(boost::math::tools::epsilon()))) - digits = - boost::math::policies::digits>() - 2; - // + // Go ahead and iterate: - // std::uintmax_t max_iter = boost::math::policies::get_max_root_iterations< boost::math::policies::policy<>>(); guess = boost::math::tools::halley_iterate(