add gamma_p_inverse_func

release/4.3a0
Varun Agrawal 2023-10-22 15:15:02 -04:00
parent 6f386168a4
commit 25ebdd54fc
1 changed files with 67 additions and 5 deletions

View File

@ -24,13 +24,12 @@
#pragma once
#include <gtsam/nonlinear/internal/Gamma.h>
#include <gtsam/nonlinear/internal/Utils.h>
#include <algorithm>
// TODO(Varun) remove
// #include <gtsam/nonlinear/internal/Halley.h>
#include <boost/math/special_functions/gamma.hpp>
#include <boost/math/tools/roots.hpp>
namespace gtsam {
@ -301,6 +300,65 @@ T find_inverse_gamma(T a, T p, T q, bool* p_has_10_digits) {
return result;
}
/**
* @brief Functional to compute the gamma inverse.
* Mainly used with Halley iteration.
*
* @tparam T
*/
template <class T>
struct gamma_p_inverse_func {
gamma_p_inverse_func(T a_, T p_, bool inv) : a(a_), p(p_), invert(inv) {
/*
If p is too near 1 then P(x) - p suffers from cancellation
errors causing our root-finding algorithms to "thrash", better
to invert in this case and calculate Q(x) - (1-p) instead.
Of course if p is *very* close to 1, then the answer we get will
be inaccurate anyway (because there's not enough information in p)
but at least we will converge on the (inaccurate) answer quickly.
*/
if (p > T(0.9)) {
p = 1 - p;
invert = !invert;
}
}
std::tuple<T, T, T> operator()(const T& x) const {
// Calculate P(x) - p and the first two derivates, or if the invert
// flag is set, then Q(x) - q and it's derivatives.
T f, f1;
T ft;
f = static_cast<T>(gamma_incomplete_imp(a, x, true, invert, &ft));
f1 = ft;
T f2;
T div = (a - x - 1) / x;
f2 = f1;
if (fabs(div) > 1) {
if (internal::LIM<T>::max() / fabs(div) < f2) {
// overflow:
f2 = -internal::LIM<T>::max() / 2;
} else {
f2 *= div;
}
} else {
f2 *= div;
}
if (invert) {
f1 = -f1;
f2 = -f2;
}
return std::make_tuple(static_cast<T>(f - p), f1, f2);
}
private:
T a, p;
bool invert;
};
template <typename T>
T gamma_p_inv_imp(const T a, const T p) {
if (is_nan(a) || is_nan(p)) {
@ -339,14 +397,18 @@ T gamma_p_inv_imp(const T a, const T p) {
uintmax_t max_iter = 200;
// TODO
// // Perform Halley iteration for root-finding to get a more refined answer
// Perform Halley iteration for root-finding to get a more refined answer
// guess = halley_iterate(gamma_p_inverse_func<T>(a, p, false), guess, lower,
// LIM<T>::max(), digits, max_iter);
// Go ahead and iterate:
// guess = boost::math::tools::halley_iterate(
// internal::gamma_p_inverse_func<T>(a, p, false), guess, lower,
// LIM<T>::max(), digits, max_iter);
guess = boost::math::tools::halley_iterate(
internal::gamma_p_inverse_func<T>(a, p, false), guess, lower,
LIM<T>::max(), digits, max_iter);
boost::math::detail::gamma_p_inverse_func<
T, boost::math::policies::policy<>>(a, p, false),
guess, lower, boost::math::tools::max_value<T>(), digits, max_iter);
if (guess == lower) {
throw std::runtime_error(