Skip to content

Instantly share code, notes, and snippets.

@carlosgmartin
Created September 19, 2024 02:45
Show Gist options
  • Save carlosgmartin/a147b43f39633dcb0a985b51a5b1af0c to your computer and use it in GitHub Desktop.
Save carlosgmartin/a147b43f39633dcb0a985b51a5b1af0c to your computer and use it in GitHub Desktop.
JAX implementation of finite differences (forward and central)
import jax
from jax import numpy as jnp
from jax.flatten_util import ravel_pytree
def centered_finite_differences(f, x, eps):
    """Estimate the gradient of ``f`` at ``x`` with centered finite differences.

    Every coordinate ``i`` of the flattened input is shifted by ``+eps`` and
    ``-eps`` (one coordinate at a time, batched with ``jax.vmap``). The partial
    derivative is estimated as ``(f(x + eps*e_i) - f(x - eps*e_i)) / spacing_i``.
    Here ``spacing_i`` is the distance between the two evaluation points as
    actually represented in floating point.

    Args:
        f: Function taking a pytree shaped like ``x``. It is expected to return
            a scalar: the per-coordinate results are stacked and passed to the
            ``unravel`` function of ``x``, which needs one value per coordinate.
        x: Pytree of arrays at which the gradient is estimated.
        eps: Nominal step added to and subtracted from each coordinate.

    Returns:
        A pytree with the same structure as ``x`` holding the estimated
        partial derivatives.
    """
    flat_x, unravel = ravel_pytree(x)

    def central_difference(index):
        # f(x + eps*e_i) - f(x - eps*e_i) for the single coordinate `index`.
        y_plus = f(unravel(flat_x.at[index].add(eps)))
        y_minus = f(unravel(flat_x.at[index].add(-eps)))
        return y_plus - y_minus

    numerators = jax.vmap(central_difference)(jnp.arange(flat_x.size))

    # Divide by the spacing that f actually saw, not the nominal 2 * eps.
    # x + eps and x - eps are rounded to the dtype of x, so their distance can
    # differ noticeably from 2 * eps (~0.14% in float32 near x = 1, eps = 1e-5).
    # Element-wise rounding does not depend on the other coordinates, so
    # perturbing all coordinates at once reproduces the per-index values used
    # above. The ravel/unravel round trip also applies any per-leaf dtype cast
    # that `unravel` performs before f is called.
    plus, _ = ravel_pytree(unravel(flat_x + eps))
    minus, _ = ravel_pytree(unravel(flat_x - eps))
    spacing = plus - minus
    # If eps is lost to rounding entirely (|x| much larger than eps), the
    # spacing is 0 and so is the numerator. Fall back to the nominal spacing so
    # the estimate is 0 rather than NaN.
    spacing = jnp.where(spacing == 0, 2 * eps, spacing)

    return unravel(numerators / spacing)
def forward_finite_differences(f, x, eps):
    """Estimate the gradient of ``f`` at ``x`` with forward finite differences.

    Every coordinate ``i`` of the flattened input is shifted by ``+eps`` (one
    coordinate at a time, batched with ``jax.vmap``). The partial derivative is
    estimated as ``(f(x + eps*e_i) - f(x)) / step_i``. Here ``step_i`` is the
    shift as actually represented in floating point.

    Args:
        f: Function taking a pytree shaped like ``x``. It is expected to return
            a scalar: the per-coordinate results are stacked and passed to the
            ``unravel`` function of ``x``, which needs one value per coordinate.
        x: Pytree of arrays at which the gradient is estimated.
        eps: Nominal step added to each coordinate.

    Returns:
        A pytree with the same structure as ``x`` holding the estimated
        partial derivatives.
    """
    flat_x, unravel = ravel_pytree(x)

    def shifted_value(index):
        # f(x + eps*e_i) for the single coordinate `index`.
        return f(unravel(flat_x.at[index].add(eps)))

    shifted = jax.vmap(shifted_value)(jnp.arange(flat_x.size))
    base = f(unravel(flat_x))

    # Divide by the step that f actually saw, not the nominal eps: x + eps is
    # rounded to the dtype of x, so the realized step can differ noticeably
    # from eps (~0.14% in float32 near x = 1, eps = 1e-5). The ravel/unravel
    # round trip also applies any per-leaf dtype cast that `unravel` performs
    # before f is called.
    plus, _ = ravel_pytree(unravel(flat_x + eps))
    step = plus - flat_x
    # If eps is lost to rounding entirely, the step is 0 and so is the
    # numerator. Fall back to the nominal eps so the estimate is 0, not NaN.
    step = jnp.where(step == 0, eps, step)

    return unravel((shifted - base) / step)
def tree_all_close(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Return True when every pair of matching leaves in ``a`` and ``b`` is close.

    Leaves are compared with ``jnp.allclose`` using the given ``rtol``,
    ``atol`` and ``equal_nan``. The two pytrees must share the same structure,
    because they are traversed together with ``jax.tree.map``.
    """
    per_leaf = jax.tree.map(
        lambda left, right: jnp.allclose(
            left, right, rtol=rtol, atol=atol, equal_nan=equal_nan
        ),
        a,
        b,
    )
    # Reduce the per-leaf boolean results to a single truth value.
    return jax.tree.all(per_leaf)
def main():
    """Compare both finite-difference estimators against ``jax.grad``.

    The test function is the dot product ``x["a"] @ x["b"]`` over two
    length-5 vectors, whose exact gradient is ``{"a": x["b"], "b": x["a"]}``.
    The function prints the autodiff, centered and forward gradients, in that
    order. It then asserts that each finite-difference estimate agrees with
    autodiff to within ``atol=1e-2``.
    """

    def dot_product(params):
        return params["a"] @ params["b"]

    params = {
        "a": jnp.linspace(-1, 1, 5),
        "b": jnp.linspace(-2, 0, 5),
    }
    step = 1e-5

    exact = jax.grad(dot_product)(params)
    centered = centered_finite_differences(dot_product, params, step)
    forward = forward_finite_differences(dot_product, params, step)

    for gradient in (exact, centered, forward):
        print(gradient)

    # The tolerance is loose because step = 1e-5 in float32 leaves rounding
    # error of roughly 1e-3 in the difference quotients.
    for estimate in (centered, forward):
        assert tree_all_close(exact, estimate, atol=1e-2)


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment