@carlosgmartin
Last active September 19, 2024 08:07
Implementation of a pseudo-gradient estimator based on JAX and Optax

from functools import partial

import jax
import optax
from jax import numpy as jnp, random


def pseudo_gradient(
    f,
    x,
    key,
    scale,
    num_samples=2,
    sampler=random.normal,
    antithetic=True,
    common_key=True,
):
    """Estimate the pseudo-gradient of a given stochastic function at a given
    point. This procedure is also known as randomized smoothing.

    For a non-scalar function, this yields a "pseudo-Jacobian".

    Args:
        f: the function, which should take a point and a key.
        x: the point at which to evaluate the pseudo-gradient.
        key: a random key.
        scale: the perturbation scale, or smoothing scale.
        num_samples: the number of samples to use.
        sampler: the noise sampling function. ``jax.random.normal`` (the
            default) yields Gaussian smoothing. ``jax.random.rademacher``
            yields simultaneous perturbation stochastic approximation (SPSA).
        antithetic: use antithetic sampling. This is a version of the variance
            reduction technique known as antithetic variates.
        common_key: let antithetic sample pairs use a common key for function
            evaluation. This is a version of the variance reduction technique
            known as common random numbers.

    Returns:
        An estimate of the pseudo-gradient, which has the same shape as x.

    References:
        [Martin and Sandholm, 2024](https://www.arxiv.org/abs/2408.09306)
    """
    if antithetic and num_samples % 2 != 0:
        raise ValueError("num_samples must be even if antithetic=True")
    perturb_key, eval_key = random.split(key)
    if antithetic:
        perturb_keys = random.split(perturb_key, num_samples // 2)
    else:
        perturb_keys = random.split(perturb_key, num_samples)
    # Draw one noise pytree per perturbation key, matching the structure of x.
    zs = jax.vmap(optax.tree_utils.tree_random_like, [0, None, None])(
        perturb_keys, x, sampler
    )
    if antithetic:
        # Mirror each perturbation to form antithetic pairs (z, -z).
        zs = jax.tree.map(lambda z: jnp.concatenate([z, -z]), zs)
    xs = jax.tree.map(lambda x, zs: x + zs * scale, x, zs)
    if antithetic and common_key:
        # Each antithetic pair shares an evaluation key (common random numbers).
        eval_keys = random.split(eval_key, num_samples // 2)
        eval_keys = jnp.concatenate([eval_keys] * 2)
    else:
        eval_keys = random.split(eval_key, num_samples)
    # Evaluate f at every perturbed point in parallel.
    values = jax.vmap(f)(xs, eval_keys)
    scale_factor = 1 / (num_samples * scale)

    def combine(zs):
        # Contract the function values against the perturbations and rescale.
        return jnp.tensordot(values, zs, (0, 0)) * scale_factor

    return jax.tree.map(combine, zs)
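

# Example (not part of the original gist): a minimal sketch of driving an
# Optax optimizer with pseudo-gradients. The optimizer choice (SGD), learning
# rate, smoothing scale, and sample count below are illustrative assumptions.
def example_optax_descent(
    f, x, key, num_steps=100, learning_rate=1e-2, scale=1e-2
):
    optimizer = optax.sgd(learning_rate)
    opt_state = optimizer.init(x)
    for step_key in random.split(key, num_steps):
        # Replace the exact gradient with the pseudo-gradient estimate.
        grads = pseudo_gradient(f, x, step_key, scale, num_samples=64)
        updates, opt_state = optimizer.update(grads, opt_state, x)
        x = optax.apply_updates(x, updates)
    return x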


def smoothe(f, scale, sampler=random.normal):
    """Smoothe a stochastic function with a given perturbation scale."""

    def g(x, key):
        keys = random.split(key)
        z = optax.tree_utils.tree_random_like(keys[0], x, sampler)
        x = jax.tree.map(lambda x, z: x + z * scale, x, z)
        return f(x, keys[1])

    return g


def average(f, num_samples):
    """Average a stochastic function over multiple samples."""

    def g(x, key):
        keys = random.split(key, num_samples)
        outputs = jax.vmap(f, [None, 0])(x, keys)
        return jax.tree.map(lambda x: x.mean(0), outputs)

    return g
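

# Relationship verified by check_pseudo_gradient below (a restatement, not
# part of the original gist): with the default Gaussian sampler,
# pseudo_gradient computes
#     (1 / (num_samples * scale)) * sum_i f(x + scale * z_i, key_i) * z_i,
# an estimate of the gradient of the smoothed objective E_z[f(x + scale * z)].
# jax.grad of average(smoothe(f, scale), num_samples) estimates the same
# smoothed gradient by differentiating through sampled perturbations, so the
# two should agree up to Monte Carlo error (hence the large num_samples and
# loose tolerances used in main).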
def check_pseudo_gradient(f, x, key, scale, num_samples, atol, rtol):
    grad = jax.grad(average(smoothe(f, scale), num_samples))(x, key)
    pgrad = pseudo_gradient(f, x, key, scale, num_samples)
    assert jnp.allclose(
        grad, pgrad, atol=atol, rtol=rtol
    ), f"\n{grad}\n{pgrad}"


def compare_variances(f, x, key, scale, num_samples):
    # Draw num_samples independent pseudo-gradient estimates per configuration
    # (each using pseudo_gradient's default of 2 function evaluations) and
    # compare their variances.
    keys = random.split(key, num_samples)
    print("variances")
    for params in [
        dict(),
        dict(common_key=False),
        dict(antithetic=False),
    ]:
        grads = jax.vmap(
            partial(pseudo_gradient, **params), [None, None, 0, None]
        )(f, x, keys, scale)
        variance = grads.var(0).sum()
        print(f"{params}: {variance:g}")


def main():
    def f(x, key):
        # A noisy quadratic: perturb the input with Laplace noise, then return
        # its squared norm.
        x += random.laplace(key, x.shape) * 0.1
        return x @ x

    x = jnp.linspace(-1, 1, 5)
    key = random.key(0)
    for scale in [1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1]:
        check_pseudo_gradient(
            f, x, key, scale, num_samples=10**6, atol=1e-1, rtol=1e-2
        )
    compare_variances(f, x, key, 1e-1, 10**7)


if __name__ == "__main__":
    main()
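

# Variant sketch (not in the original gist): per pseudo_gradient's docstring,
# passing sampler=random.rademacher yields SPSA-style perturbations, e.g.
#     pgrad = pseudo_gradient(
#         f, x, key, 1e-2, num_samples=100, sampler=random.rademacher
#     )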