Created September 19, 2024 02:45
Save carlosgmartin/a147b43f39633dcb0a985b51a5b1af0c to your computer and use it in GitHub Desktop.
JAX implementation of finite differences (forward and central)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
from jax import numpy as jnp | |
from jax.flatten_util import ravel_pytree | |
def centered_finite_differences(f, x, eps):
    """Approximate the gradient of ``f`` at ``x`` with central differences.

    The pytree ``x`` is flattened with ``ravel_pytree``. Each coordinate of the
    flat vector is perturbed by ``+eps`` and ``-eps``, and ``f`` is evaluated
    at both points. The ``x.size`` evaluation pairs are batched with
    ``jax.vmap``.

    Args:
        f: Function taking a pytree with the same structure as ``x``. It should
            return a scalar: the stacked per-coordinate results are unraveled
            back into the structure of ``x``.
        x: Pytree of arrays at which to differentiate. Leaves are assumed to be
            floating point, because the perturbation is added in the dtype of
            the flattened input.
        eps: Perturbation applied to one coordinate at a time.

    Returns:
        Pytree with the structure of ``x`` whose entries are
        ``(f(x + eps*e_i) - f(x - eps*e_i)) / step_i``.
    """
    flat, unravel = ravel_pytree(x)

    def helper(index):
        xp = flat.at[index].add(eps)
        xm = flat.at[index].add(-eps)
        yp = f(unravel(xp))
        ym = f(unravel(xm))
        # x +/- eps is rounded to the input's precision, so the realised
        # spacing is generally not exactly 2*eps. Dividing by the spacing that
        # was actually taken removes that rounding error from the quotient.
        step = xp[index] - xm[index]
        return (yp - ym) / step

    g = jax.vmap(helper)(jnp.arange(flat.size))
    return unravel(g)
def forward_finite_differences(f, x, eps):
    """Approximate the gradient of ``f`` at ``x`` with forward differences.

    The pytree ``x`` is flattened with ``ravel_pytree``. ``f`` is evaluated
    once at the unperturbed point, then once per coordinate with that
    coordinate shifted by ``+eps``. The shifted evaluations are batched with
    ``jax.vmap``.

    Args:
        f: Function taking a pytree with the same structure as ``x``. It should
            return a scalar: the stacked per-coordinate results are unraveled
            back into the structure of ``x``.
        x: Pytree of arrays at which to differentiate. Leaves are assumed to be
            floating point, because the perturbation is added in the dtype of
            the flattened input.
        eps: Perturbation applied to one coordinate at a time.

    Returns:
        Pytree with the structure of ``x`` whose entries are
        ``(f(x + eps*e_i) - f(x)) / step_i``.
    """
    flat, unravel = ravel_pytree(x)
    y = f(unravel(flat))

    def helper(index):
        xp = flat.at[index].add(eps)
        # x + eps is rounded to the input's precision, so the realised spacing
        # is generally not exactly eps. Divide by the spacing actually taken.
        step = xp[index] - flat[index]
        return (f(unravel(xp)) - y) / step

    g = jax.vmap(helper)(jnp.arange(flat.size))
    return unravel(g)
def tree_all_close(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Return whether two pytrees have equal structure and close leaves.

    Leaves are compared pairwise with ``jnp.allclose``, using the given
    ``rtol``, ``atol`` and ``equal_nan``.

    Args:
        a: First pytree.
        b: Second pytree.
        rtol: Relative tolerance forwarded to ``jnp.allclose``.
        atol: Absolute tolerance forwarded to ``jnp.allclose``.
        equal_nan: Whether NaNs in matching positions compare as equal.

    Returns:
        ``False`` if the tree structures differ. Otherwise ``True`` iff every
        pair of corresponding leaves is close.
    """
    # Check structure explicitly. Without this, jax.tree.map raises on
    # mismatched trees, and a `b` that merely has `a` as a prefix would hand
    # whole subtrees to allclose.
    if jax.tree.structure(a) != jax.tree.structure(b):
        return False

    def close(u, v):
        return jnp.allclose(u, v, rtol=rtol, atol=atol, equal_nan=equal_nan)

    return jax.tree.all(jax.tree.map(close, a, b))
def main():
    """Check both finite-difference schemes against ``jax.grad`` on a demo.

    The test function is the dot product ``params["a"] @ params["b"]`` of two
    length-5 vectors. The autodiff gradient, the central-difference estimate
    and the forward-difference estimate are printed in that order. Each
    estimate is then asserted to match the autodiff gradient with
    ``atol=1e-2``.
    """

    def dot_product(params):
        return params["a"] @ params["b"]

    params = {
        "a": jnp.linspace(-1, 1, 5),
        "b": jnp.linspace(-2, 0, 5),
    }
    step = 1e-5

    exact = jax.grad(dot_product)(params)
    central, forward = [
        scheme(dot_product, params, step)
        for scheme in (centered_finite_differences, forward_finite_differences)
    ]

    for tree in (exact, central, forward):
        print(tree)

    for estimate in (central, forward):
        assert tree_all_close(exact, estimate, atol=1e-2)


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.