-
-
Save nouiz/2038703 to your computer and use it in GitHub Desktop.
lngamma in theano
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import theano | |
import theano.tensor as T | |
def log_gamma_lanczos(z):
    """Build a symbolic expression for log(Gamma(z)) using the Lanczos series.

    For ``z >= 0.5`` the series in ``log_gamma_lanczos_sub`` is evaluated
    directly.  For ``z < 0.5`` the reflection formula

        log(Gamma(z)) = log(pi) - log(sin(pi * z)) - log(Gamma(1 - z))

    is used instead (normally only needed for negative arguments, but it
    also improves accuracy for 0 < z < 0.5).

    Parameters
    ----------
    z : symbolic tensor whose ``dtype`` string starts with ``"float"``.

    Returns
    -------
    The elementwise result selected by ``T.switch(z < 0.5, ...)``.

    Raises
    ------
    AssertionError
        If ``z.dtype`` does not start with ``"float"``.
    """
    assert z.dtype.startswith("float")

    flip_z = 1 - z
    # Both paths (reflected and non-reflected) are always executed, so the
    # reflected path also sees arguments it is never selected for.  For any
    # z >= 1, flip_z <= 0, which is outside the useful range of the series;
    # in particular flip_z == 0 (z == 1) makes the first series term of
    # log_gamma_lanczos_sub divide by zero, and the resulting inf can turn
    # into NaN when gradients are taken through the discarded branch.
    # Replace every flip_z <= 0 by a harmless dummy value: those results are
    # thrown away by the final T.switch anyway, so forward values are
    # unaffected.
    flip_z = T.switch(flip_z <= 0, 1, flip_z)

    log_pi = np.asarray(np.log(np.pi), dtype=z.dtype)
    # NOTE(review): where sin(pi * z) is negative the real logarithm is
    # undefined and this yields NaN; confirm whether log|Gamma(z)| is wanted
    # for those inputs.
    small = log_pi - T.log(T.sin(np.pi * z)) - log_gamma_lanczos_sub(flip_z)
    big = log_gamma_lanczos_sub(z)
    return T.switch(z < 0.5, small, big)
def log_gamma_lanczos_sub(z):  # vectorised version
    """Evaluate the Lanczos series for log(Gamma(z)) with one vectorised sum.

    The series terms ``p[i] / (z - 1 + i)`` for ``i = 1 .. g + 1`` are
    stacked along a new leading axis and reduced with ``T.sum``.

    Note: a later definition with the same name in this file replaces this
    one when the module is executed.

    Parameters
    ----------
    z : symbolic float tensor; its ``dtype`` and ``ndim`` are used to build
        the constants and to broadcast the series index.

    Returns
    -------
    Symbolic tensor with the elementwise approximation.
    """
    # Coefficients used by the GNU Scientific Library
    g = 7
    p = np.array([0.99999999999980993, 676.5203681218851, -1259.1392167224028,
                  771.32342877765313, -176.61502916214059, 12.507343278686905,
                  -0.13857109526572012, 9.9843695780195716e-6,
                  1.5056327351493116e-7], dtype=z.dtype)
    z = z - 1
    # Create the series indices 1 .. g + 1 directly in z's dtype.  A default
    # (integer) arange would upcast a float32 input to float64 when added
    # to z, defeating the dtype-matched constants used everywhere else.
    indices = T.arange(1, g + 2, dtype=z.dtype)
    # Index axis goes first; z is broadcast against it.
    zs = T.shape_padleft(z, 1) + T.shape_padright(indices, z.ndim)
    x = T.sum(T.shape_padright(p[1:], z.ndim) / zs, axis=0) + p[0]
    t = z + g + 0.5
    log_sqrt_2pi = np.asarray(np.log(np.sqrt(2 * np.pi)), dtype=z.dtype)
    return log_sqrt_2pi + (z + 0.5) * T.log(t) - t + T.log(x)
## alternative version that isn't vectorised, since g is small anyway | |
def log_gamma_lanczos_sub(z):  # expanded version
    """Evaluate the Lanczos series for log(Gamma(z)) with an unrolled sum.

    Same formula as the vectorised variant, but the ``g + 1`` series terms
    are added one by one in a Python loop, which is cheap because ``g`` is
    small.

    Parameters
    ----------
    z : symbolic float tensor; its ``dtype`` is used for all constants.

    Returns
    -------
    Symbolic tensor with the elementwise approximation.
    """
    # Coefficients used by the GNU Scientific Library
    g = 7
    p = np.array([0.99999999999980993, 676.5203681218851, -1259.1392167224028,
                  771.32342877765313, -176.61502916214059, 12.507343278686905,
                  -0.13857109526572012, 9.9843695780195716e-6,
                  1.5056327351493116e-7], dtype=z.dtype)
    z = z - 1
    # Accumulate p[0] + sum_i p[i] / (z + i) term by term.
    x = p[0]
    for i in range(1, g + 2):
        x = x + p[i] / (z + i)
    t = z + g + 0.5
    log_sqrt_2pi = np.asarray(np.log(np.sqrt(2 * np.pi)), dtype=z.dtype)
    return log_sqrt_2pi + (z + 0.5) * T.log(t) - t + T.log(x)
if __name__ == "__main__":
    # Demo: print the symbolic graph before and after compilation.
    inp = T.vector()
    expr = log_gamma_lanczos(inp)
    theano.printing.debugprint(expr, print_type=True)
    compiled = theano.function([inp], expr)
    theano.printing.debugprint(compiled, print_type=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment