Implementation of a pseudo-gradient estimator based on JAX and Optax
from functools import partial

import jax
import optax
from jax import numpy as jnp, random


def pseudo_gradient(
    f,
    x,
    key,
    scale,
    num_samples=2,
    sampler=random.normal,
    antithetic=True,
    common_key=True,
):
    """Estimate the pseudo-gradient of a given stochastic function at a given
    point. This procedure is also known as randomized smoothing.

    For a non-scalar function, this yields a "pseudo-Jacobian".

    Args:
        f: the function, which should take a point and a key.
        x: the point at which to evaluate the pseudo-gradient.
        key: a random key.
        scale: the perturbation scale, or smoothing scale.
        num_samples: the number of samples to use.
        sampler: the noise sampling function. ``jax.random.normal`` (the
            default) yields Gaussian smoothing. ``jax.random.rademacher``
            yields simultaneous perturbation stochastic approximation (SPSA).
        antithetic: use antithetic sampling. This is a version of the variance
            reduction technique known as antithetic variates.
        common_key: let antithetic sample pairs use a common key for function
            evaluation. This is a version of the variance reduction technique
            known as common random numbers.

    Returns:
        an estimate of the pseudo-gradient, which has the same shape as x.

    References:
        [Martin and Sandholm, 2024](https://www.arxiv.org/abs/2408.09306)
    """
    if antithetic and num_samples % 2 != 0:
        raise ValueError("num_samples must be even if antithetic=True")
    perturb_key, eval_key = random.split(key)
    if antithetic:
        perturb_keys = random.split(perturb_key, num_samples // 2)
    else:
        perturb_keys = random.split(perturb_key, num_samples)
    # Draw one perturbation per key, matching the pytree structure of x.
    zs = jax.vmap(optax.tree_utils.tree_random_like, [0, None, None])(
        perturb_keys, x, sampler
    )
    if antithetic:
        # Mirror each perturbation to form the antithetic pairs (z, -z).
        zs = jax.tree.map(lambda z: jnp.concatenate([z, -z]), zs)
    xs = jax.tree.map(lambda x, zs: x + zs * scale, x, zs)
    if antithetic and common_key:
        # Reuse the same evaluation key within each antithetic pair
        # (common random numbers).
        eval_keys = random.split(eval_key, num_samples // 2)
        eval_keys = jnp.concatenate([eval_keys] * 2)
    else:
        eval_keys = random.split(eval_key, num_samples)
    values = jax.vmap(f)(xs, eval_keys)
    scale_factor = 1 / (num_samples * scale)

    def combine(zs):
        # Weight each perturbation by its function value and average.
        return jnp.tensordot(values, zs, (0, 0)) * scale_factor

    return jax.tree.map(combine, zs)


def smoothe(f, scale, sampler=random.normal):
    """Smoothe a stochastic function with a given perturbation scale."""

    def g(x, key):
        keys = random.split(key)
        z = optax.tree_utils.tree_random_like(keys[0], x, sampler)
        x = jax.tree.map(lambda x, z: x + z * scale, x, z)
        return f(x, keys[1])

    return g


def average(f, num_samples):
    """Average a stochastic function over multiple samples."""

    def g(x, key):
        keys = random.split(key, num_samples)
        outputs = jax.vmap(f, [None, 0])(x, keys)
        return jax.tree.map(lambda x: x.mean(0), outputs)

    return g


def check_pseudo_gradient(f, x, key, scale, num_samples, atol, rtol):
    # Check the estimator against autodiff through an explicitly smoothed and
    # sample-averaged version of f.
    grad = jax.grad(average(smoothe(f, scale), num_samples))(x, key)
    pgrad = pseudo_gradient(f, x, key, scale, num_samples)
    assert jnp.allclose(
        grad, pgrad, atol=atol, rtol=rtol
    ), f"\n{grad}\n{pgrad}"


def compare_variances(f, x, key, scale, num_samples):
    # Compare the empirical variance of the estimator with both variance
    # reduction techniques enabled (the default), without common keys, and
    # without antithetic sampling.
    keys = random.split(key, num_samples)
    print("variances")
    for params in [
        dict(),
        dict(common_key=False),
        dict(antithetic=False),
    ]:
        grads = jax.vmap(
            partial(pseudo_gradient, **params), [None, None, 0, None]
        )(f, x, keys, scale)
        variance = grads.var(0).sum()
        print(f"{params}: {variance:g}")


def main():
    def f(x, key):
        # Toy stochastic objective: squared norm of a Laplace-perturbed input.
        x += random.laplace(key, x.shape) * 0.1
        return x @ x

    x = jnp.linspace(-1, 1, 5)
    key = random.key(0)
    for scale in [1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1]:
        check_pseudo_gradient(
            f, x, key, scale, num_samples=10**6, atol=1e-1, rtol=1e-2
        )
    compare_variances(f, x, key, 1e-1, 10**7)


if __name__ == "__main__":
    main()
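For completeness, here is a minimal usage sketch. It assumes the file above has been run or imported, so that pseudo_gradient and the jax imports are in scope; the toy objective g and the particular key, scale, and sample count are illustrative choices, not part of the original gist.

def g(x, key):
    # Hypothetical toy objective: squared norm of a noisy copy of x.
    noise = random.normal(key, x.shape) * 0.05
    return jnp.sum((x + noise) ** 2)

x = jnp.array([1.0, -2.0, 3.0])
grad_estimate = pseudo_gradient(
    g, x, random.key(1), scale=1e-2, num_samples=1024
)
print(grad_estimate)  # roughly 2 * x = [2, -4, 6], up to sampling noise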