From 25ebdd54fc23ca0624a42540e941cd04abc7eb9e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 22 Oct 2023 15:15:02 -0400 Subject: [PATCH] add gamma_p_inverse_func --- gtsam/nonlinear/internal/chiSquaredInverse.h | 72 ++++++++++++++++++-- 1 file changed, 67 insertions(+), 5 deletions(-) diff --git a/gtsam/nonlinear/internal/chiSquaredInverse.h b/gtsam/nonlinear/internal/chiSquaredInverse.h index b7744ffa2..7577721dc 100644 --- a/gtsam/nonlinear/internal/chiSquaredInverse.h +++ b/gtsam/nonlinear/internal/chiSquaredInverse.h @@ -24,13 +24,12 @@ #pragma once -#include #include #include // TODO(Varun) remove -// #include +#include #include namespace gtsam { @@ -301,6 +300,65 @@ T find_inverse_gamma(T a, T p, T q, bool* p_has_10_digits) { return result; } +/** + * @brief Functional to compute the gamma inverse. + * Mainly used with Halley iteration. + * + * @tparam T + */ +template +struct gamma_p_inverse_func { + gamma_p_inverse_func(T a_, T p_, bool inv) : a(a_), p(p_), invert(inv) { + /* + If p is too near 1 then P(x) - p suffers from cancellation + errors causing our root-finding algorithms to "thrash", better + to invert in this case and calculate Q(x) - (1-p) instead. + + Of course if p is *very* close to 1, then the answer we get will + be inaccurate anyway (because there's not enough information in p) + but at least we will converge on the (inaccurate) answer quickly. + */ + if (p > T(0.9)) { + p = 1 - p; + invert = !invert; + } + } + + std::tuple operator()(const T& x) const { + // Calculate P(x) - p and the first two derivates, or if the invert + // flag is set, then Q(x) - q and it's derivatives. + T f, f1; + T ft; + f = static_cast(gamma_incomplete_imp(a, x, true, invert, &ft)); + f1 = ft; + T f2; + T div = (a - x - 1) / x; + f2 = f1; + + if (fabs(div) > 1) { + if (internal::LIM::max() / fabs(div) < f2) { + // overflow: + f2 = -internal::LIM::max() / 2; + } else { + f2 *= div; + } + } else { + f2 *= div; + } + + if (invert) { + f1 = -f1; + f2 = -f2; + } + + return std::make_tuple(static_cast(f - p), f1, f2); + } + + private: + T a, p; + bool invert; +}; + template T gamma_p_inv_imp(const T a, const T p) { if (is_nan(a) || is_nan(p)) { @@ -339,14 +397,18 @@ T gamma_p_inv_imp(const T a, const T p) { uintmax_t max_iter = 200; // TODO - // // Perform Halley iteration for root-finding to get a more refined answer + // Perform Halley iteration for root-finding to get a more refined answer // guess = halley_iterate(gamma_p_inverse_func(a, p, false), guess, lower, // LIM::max(), digits, max_iter); // Go ahead and iterate: + // guess = boost::math::tools::halley_iterate( + // internal::gamma_p_inverse_func(a, p, false), guess, lower, + // LIM::max(), digits, max_iter); guess = boost::math::tools::halley_iterate( - internal::gamma_p_inverse_func(a, p, false), guess, lower, - LIM::max(), digits, max_iter); + boost::math::detail::gamma_p_inverse_func< + T, boost::math::policies::policy<>>(a, p, false), + guess, lower, boost::math::tools::max_value(), digits, max_iter); if (guess == lower) { throw std::runtime_error(