Last active
July 12, 2021 18:37
-
-
Save lazyoracle/258aae32a9612ae1f9211e5fde854677 to your computer and use it in GitHub Desktop.
Numerical integration using the trapezoidal rule, parallelized with Numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Callable | |
from numba import njit, prange | |
import numpy as np | |
@njit()
def my_square(x: float) -> float:
    """Return ``x`` raised to the second power.

    Parameters
    ----------
    x : float
        Value to square

    Returns
    -------
    float
        ``x ** 2``
    """
    squared = x ** 2
    return squared
def bayes_numr(x: float, t: float) -> float:
    """Evaluate the standard normal density at x times cos^2(x * t / 2).

    This function is intentionally left undecorated; the script wraps it
    with ``njit`` at the point where it is handed to the integrator.

    Parameters
    ----------
    x : float
        Point at which the integrand is evaluated
    t : float
        Multiplier of ``x`` inside the cosine term

    Returns
    -------
    float
        N(x; 0, 1) * cos^2(x * t / 2)
    """
    # Density of the zero-mean, unit-variance normal distribution at x.
    gaussian = (1. / np.sqrt(2. * np.pi)) * np.exp(-np.square(x) / 2.)
    # Squared-cosine factor.
    cos_squared = np.square(np.cos(x * t / 2.0))
    return gaussian * cos_squared
# Compiled copy of the integrand: code running in nopython mode cannot call
# the plain (un-jitted) Python function ``bayes_numr`` directly.
_bayes_numr_jit = njit(bayes_numr)


@njit(parallel=True)
def p_total_prob(low: float, high: float, samples: int, t: float) -> float:
    """Calculate the definite integral

    N_x(0, 1) * cos^2 (x.t/2) dx between low and high

    Using the composite trapezoidal rule

    Parameters
    ----------
    low : float
        lower bound of integral
    high : float
        upper bound of integral
    samples : int
        number of sub-intervals used for the approximation (must be >= 1)
    t : float
        argument for the integrand

    Returns
    -------
    float
        The definite integral

    Raises
    ------
    ValueError
        If ``samples`` is smaller than 1
    """
    if samples < 1:
        raise ValueError("samples must be a positive integer")
    del_x = (high - low) / samples
    # Interior nodes only (i = 1 .. samples - 1). Starting at i = 0 would add
    # the lower endpoint a second time with weight 2 and bias the result by
    # del_x * f(low).
    interior = 0.0
    for i in prange(1, samples):
        interior += _bayes_numr_jit(low + (del_x * i), t)
    # Endpoints carry weight 1, interior nodes weight 2.
    endpoints = _bayes_numr_jit(low, t) + _bayes_numr_jit(high, t)
    return (endpoints + 2. * interior) * del_x / 2.
@njit(parallel=True)
def trap_integrator(integrand: Callable, low: float, high: float, samples: int, *args: float) -> float:
    """Calculate definite integral using the composite Trapezoidal Rule

    Parameters
    ----------
    integrand : Callable
        The function to integrate, called as ``integrand(x, *args)``
    low : float
        Lower bound of integration
    high : float
        Upper bound of integration
    samples : int
        Number of sub-intervals (must be >= 1)
    *args : float
        Extra positional arguments forwarded unchanged to ``integrand``

    Returns
    -------
    float
        Result of the definite integral

    Raises
    ------
    ValueError
        If ``samples`` is smaller than 1
    """
    if samples < 1:
        raise ValueError("samples must be a positive integer")
    del_x = (high - low) / samples
    # Sum the interior nodes only (i = 1 .. samples - 1). Starting the loop
    # at i = 0 would count the lower endpoint with weight 3 instead of 1.
    interior = 0.0
    for i in prange(1, samples):
        interior += integrand(low + (del_x * i), *args)
    # Endpoints carry weight 1, interior nodes weight 2.
    endpoints = integrand(low, *args) + integrand(high, *args)
    return (endpoints + 2. * interior) * del_x / 2.
@njit()
def normal_func(x: float, mean: float, sigma: float) -> float:
    """Evaluate the normal probability density at x

    Parameters
    ----------
    x : float
        Point at which the density is evaluated
    mean : float
        Mean of the distribution
    sigma : float
        Standard deviation of the distribution

    Returns
    -------
    float
        Value of the density with the given mean and sigma at x
    """
    # Normalisation constant 1 / (sigma * sqrt(2 * pi)).
    scale = 1. / (sigma * np.sqrt(2. * np.pi))
    # Distance from the mean in units of sigma.
    z_score = (x - mean) / sigma
    return scale * np.exp(-(1/2.) * np.square(z_score))
# Integrate bayes_numr over [-1e5, 1e5] with t = pi, using 1e9 samples.
bayes_integrand = njit(bayes_numr)
print(trap_integrator(bayes_integrand, -1e5, 1e5, int(1e9), np.pi))

# Integrate a**2 + b over [0, 1] with b = 5, using 1e12 samples.
shifted_square = njit(lambda a, b: (a**2) + b)
print(trap_integrator(shifted_square, 0, 1, int(1e12), 5))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment