gamma inverse functional

release/4.3a0
Varun Agrawal 2023-10-20 11:09:35 -04:00
parent bebb275489
commit 6f386168a4
3 changed files with 152 additions and 10 deletions

View File

@ -0,0 +1,94 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file Gamma.h
* @brief Gamma and Gamma Inverse functions
*
* A lot of this code has been picked up from
* https://www.boost.org/doc/libs/1_83_0/boost/math/special_functions/detail/igamma_inverse.hpp
*
* @author Varun Agrawal
*/
#pragma once
#include <gtsam/nonlinear/internal/Utils.h>
#include <boost/math/special_functions/gamma.hpp>
namespace gtsam {
namespace internal {
/**
* @brief Functional to compute the gamma inverse.
* Mainly used with Halley iteration.
*
* @tparam T
*/
template <class T>
struct gamma_p_inverse_func {
gamma_p_inverse_func(T a_, T p_, bool inv) : a(a_), p(p_), invert(inv) {
/*
If p is too near 1 then P(x) - p suffers from cancellation
errors causing our root-finding algorithms to "thrash", better
to invert in this case and calculate Q(x) - (1-p) instead.
Of course if p is *very* close to 1, then the answer we get will
be inaccurate anyway (because there's not enough information in p)
but at least we will converge on the (inaccurate) answer quickly.
*/
if (p > T(0.9)) {
p = 1 - p;
invert = !invert;
}
}
std::tuple<T, T, T> operator()(const T& x) const {
// Calculate P(x) - p and the first two derivates, or if the invert
// flag is set, then Q(x) - q and it's derivatives.
T f, f1;
T ft;
boost::math::policies::policy<> pol;
f = static_cast<T>(boost::math::detail::gamma_incomplete_imp(
a, x, true, invert, pol, &ft));
f1 = ft;
T f2;
T div = (a - x - 1) / x;
f2 = f1;
if (fabs(div) > 1) {
if (internal::LIM<T>::max() / fabs(div) < f2) {
// overflow:
f2 = -internal::LIM<T>::max() / 2;
} else {
f2 *= div;
}
} else {
f2 *= div;
}
if (invert) {
f1 = -f1;
f2 = -f2;
}
return std::make_tuple(static_cast<T>(f - p), f1, f2);
}
private:
T a, p;
bool invert;
};
} // namespace internal
} // namespace gtsam

View File

@ -0,0 +1,49 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file Utils.h
* @brief Utilities for the Chi Squared inverse and related operations.
* @author Varun Agrawal
*/
#pragma once
namespace gtsam {
namespace internal {
/// Template type for numeric limits
template <class T>
using LIM = std::numeric_limits<T>;
template <typename T>
using return_t =
typename std::conditional<std::is_integral<T>::value, double, T>::type;
/// Get common type amongst all arguments
template <typename... T>
using common_t = typename std::common_type<T...>::type;
/// Helper template for finding common return type
template <typename... T>
using common_return_t = return_t<common_t<T...>>;
/// Check if integer is odd
constexpr bool is_odd(const long long int x) noexcept { return (x & 1U) != 0; }
/// Templated check for NaN
template <typename T>
constexpr bool is_nan(const T x) noexcept {
return x != x;
}
} // namespace internal
} // namespace gtsam

View File

@ -25,13 +25,13 @@
#pragma once #pragma once
#include <gtsam/nonlinear/internal/Gamma.h> #include <gtsam/nonlinear/internal/Gamma.h>
#include <gtsam/nonlinear/internal/Halley.h>
#include <gtsam/nonlinear/internal/Utils.h> #include <gtsam/nonlinear/internal/Utils.h>
#include <algorithm> #include <algorithm>
// TODO(Varun) remove // TODO(Varun) remove
#include <boost/math/distributions/gamma.hpp> // #include <gtsam/nonlinear/internal/Halley.h>
#include <boost/math/tools/roots.hpp>
namespace gtsam { namespace gtsam {
@ -320,13 +320,15 @@ T gamma_p_inv_imp(const T a, const T p) {
// Get an initial guess (https://dl.acm.org/doi/abs/10.1145/22721.23109) // Get an initial guess (https://dl.acm.org/doi/abs/10.1145/22721.23109)
bool has_10_digits = false; bool has_10_digits = false;
T guess = find_inverse_gamma<T>(a, p, 1 - p, &has_10_digits); T guess = find_inverse_gamma<T>(a, p, 1 - p, &has_10_digits);
if (has_10_digits) {
return guess;
}
T lower = LIM<T>::min(); T lower = LIM<T>::min();
if (guess <= lower) { if (guess <= lower) {
guess = LIM<T>::min(); guess = LIM<T>::min();
} }
// TODO
// The number of digits to converge to. // The number of digits to converge to.
// This is an arbitrary but reasonable number, // This is an arbitrary but reasonable number,
// though Boost does more sophisticated things // though Boost does more sophisticated things
@ -334,20 +336,17 @@ T gamma_p_inv_imp(const T a, const T p) {
unsigned digits = 25; unsigned digits = 25;
// Number of Halley iterations // Number of Halley iterations
// The default used in Boost is 200 uintmax_t max_iter = 200;
// uint_fast16_t max_iter = 200;
// TODO
// // Perform Halley iteration for root-finding to get a more refined answer // // Perform Halley iteration for root-finding to get a more refined answer
// guess = halley_iterate(gamma_p_inverse_func<T>(a, p, false), guess, lower, // guess = halley_iterate(gamma_p_inverse_func<T>(a, p, false), guess, lower,
// LIM<T>::max(), digits, max_iter); // LIM<T>::max(), digits, max_iter);
// Go ahead and iterate: // Go ahead and iterate:
std::uintmax_t max_iter = boost::math::policies::get_max_root_iterations<
boost::math::policies::policy<>>();
guess = boost::math::tools::halley_iterate( guess = boost::math::tools::halley_iterate(
boost::math::detail::gamma_p_inverse_func< internal::gamma_p_inverse_func<T>(a, p, false), guess, lower,
T, boost::math::policies::policy<>>(a, p, false), LIM<T>::max(), digits, max_iter);
guess, lower, boost::math::tools::max_value<T>(), digits, max_iter);
if (guess == lower) { if (guess == lower) {
throw std::runtime_error( throw std::runtime_error(