Last active
December 8, 2024 20:26
-
-
Save carlosgmartin/a3055c7605157a54d48d108226a48b97 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Improving the lowering and compilation of unrolled lax.scan loops | |
https://github.com/jax-ml/jax/discussions/25336 | |
""" | |
import argparse | |
import functools | |
import jax | |
from jax import lax, numpy as jnp, random | |
def tree_get(tree, index):
    """Return a new pytree whose leaves are ``leaf[index]`` for each leaf of ``tree``."""

    def take(leaf):
        return leaf[index]

    return jax.tree.map(take, tree)
def tree_set(tree, index, item):
    """Return a copy of ``tree`` with ``item``'s leaves written at ``index``.

    ``tree`` and ``item`` must share the same pytree structure; each leaf of
    ``tree`` is functionally updated via ``leaf.at[index].set(value)``.
    """

    def put(leaf, value):
        return leaf.at[index].set(value)

    return jax.tree.map(put, tree, item)
def get_returns(rewards, discounts, unroll=1, scan="lax"): | |
def f(carry, reward_discount): | |
reward, discount = reward_discount | |
new_carry = reward + discount * carry | |
return new_carry, new_carry | |
xs = rewards, discounts | |
match scan: | |
case "lax": | |
_, returns = lax.scan(f, 0.0, xs, unroll=unroll, reverse=True) | |
case "v1": | |
_, returns = scan_v1(f, 0.0, xs, unroll=unroll, reverse=True) | |
case "v2": | |
_, returns = scan_v2(f, 0.0, xs, unroll=unroll, reverse=True) | |
case _: | |
raise NotImplementedError | |
return returns | |
def repeat_with_outputs(n, f, x, reverse=False):
    """Call ``x, y = f(x)`` ``n`` times and concatenate the per-call outputs.

    ``f`` is wrapped with ``jax.jit`` once and invoked ``n`` times from a Python
    loop. Each ``y`` leaf is concatenated along axis 0, so ``f`` is expected to
    emit outputs with a leading axis (e.g. of size 1).

    Args:
      n: number of calls; ``n <= 0`` returns ``x`` (dtype-cast) and empty outputs.
      f: function ``state -> (state, output)`` whose output state has the same
        pytree structure as ``x``.
      x: initial state pytree.
      reverse: if true, concatenate the outputs in reverse call order.

    Returns:
      ``(final_state, ys)`` where each leaf of ``ys`` is the concatenation of
      the corresponding output leaves.
    """
    f = jax.jit(f)
    # Prevent creation of near-duplicate StableHLO function due to slight
    # differences in signature for the first iteration.
    # https://openxla.org/stablehlo/spec#functions
    x_abs, y_abs = jax.eval_shape(f, x)
    x = jax.tree.map(lambda leaf, s: jnp.asarray(leaf, s.dtype), x, x_abs)
    if n <= 0:
        # Concatenating zero outputs: build empty leaves with the structure and
        # dtype ``f`` would produce instead of calling tree.map with no trees.
        ys = jax.tree.map(
            lambda y: jnp.zeros((0,) + y.shape[1:], y.dtype), y_abs
        )
        return x, ys
    outputs = []
    for _ in range(n):
        x, y = f(x)
        outputs.append(y)
    if reverse:
        outputs.reverse()
    ys = jax.tree.map(lambda *leaves: jnp.concatenate(leaves), *outputs)
    return x, ys
def scan_v1(f, h, xs, unroll=1, reverse=False):
    """Fully unrolled scan built from repeated calls of one jitted step.

    The step takes ``(carry, index)``, reads ``xs`` at ``index``, applies
    ``f``, and advances the index; ``repeat_with_outputs`` runs it once per
    element and concatenates the per-step outputs.

    Args:
      f: step function ``(carry, x) -> (carry, y)``, as for ``lax.scan``.
      h: initial carry.
      xs: pytree of arrays scanned along their leading axis.
      unroll: must be ``True``; only full unrolling is implemented.
      reverse: if true, visit elements from last to first; outputs are still
        returned in index order.

    Returns:
      ``(final_carry, ys)``; each leaf of ``ys`` has one row per element of xs.

    Raises:
      NotImplementedError: if ``unroll`` is not ``True``.
    """
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if unroll is not True:
        raise NotImplementedError(
            f"scan_v1 only supports unroll=True (full unrolling), got {unroll!r}"
        )
    length = jax.tree.leaves(xs)[0].shape[0]
    if length == 0:
        # No steps to run: indexing a zero-length axis is invalid, so derive the
        # output structure abstractly from one element's shape/dtype.
        x_abs = jax.tree.map(
            lambda x: jax.ShapeDtypeStruct(x.shape[1:], x.dtype), xs
        )
        _, y_abs = jax.eval_shape(f, h, x_abs)
        ys = jax.tree.map(lambda y: jnp.zeros((0,) + y.shape, y.dtype), y_abs)
        return h, ys

    @jax.jit
    def step(carry_index):
        carry, i = carry_index
        x_i = tree_get(xs, i)
        carry, y_i = f(carry, x_i)
        # Add a leading axis of size 1 so per-step outputs can be concatenated.
        y_i = jax.tree.map(lambda t: jnp.expand_dims(t, 0), y_i)
        i = i - 1 if reverse else i + 1
        return (carry, i), y_i

    start = length - 1 if reverse else 0
    (h, _), ys = repeat_with_outputs(length, step, (h, start), reverse)
    return h, ys
def repeat(n, f, x):
    """Return the result of applying ``f`` to ``x`` ``n`` times.

    ``f`` is wrapped with ``jax.jit`` once and then invoked from a Python loop.
    Its output must have the same pytree structure as ``x``.
    """
    step = jax.jit(f)
    # Cast the initial state to the dtypes ``step`` returns, so the first call
    # has the same signature as every later call. This prevents creation of a
    # near-duplicate StableHLO function for the first iteration.
    # https://openxla.org/stablehlo/spec#functions
    out_struct = jax.eval_shape(step, x)

    def cast(leaf, struct):
        return jnp.asarray(leaf, struct.dtype)

    state = jax.tree.map(cast, x, out_struct)
    for _ in range(n):
        state = step(state)
    return state
def scan_v2(f, h, xs, unroll=1, reverse=False):
    """Fully unrolled scan that writes each step's output into a preallocated buffer.

    The carried state is ``(carry, index, ys)``. Each step reads ``xs`` at
    ``index``, applies ``f``, and writes the output row at ``index`` via
    ``tree_set``. ``repeat`` runs the step once per element.

    Args:
      f: step function ``(carry, x) -> (carry, y)``, as for ``lax.scan``.
      h: initial carry.
      xs: pytree of arrays scanned along their leading axis.
      unroll: must be ``True``; only full unrolling is implemented.
      reverse: if true, visit elements from last to first; outputs are still
        stored in index order.

    Returns:
      ``(final_carry, ys)``; each leaf of ``ys`` has one row per element of xs.

    Raises:
      NotImplementedError: if ``unroll`` is not ``True``.
    """
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if unroll is not True:
        raise NotImplementedError(
            f"scan_v2 only supports unroll=True (full unrolling), got {unroll!r}"
        )
    length = jax.tree.leaves(xs)[0].shape[0]
    # Abstract value of one element of ``xs`` (leading axis dropped). This gives
    # the output structure without indexing ``xs``, and works for length 0.
    x_abs = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape[1:], x.dtype), xs)
    _, y_abs = jax.eval_shape(f, h, x_abs)
    ys = jax.tree.map(
        lambda y: jnp.empty_like(y, shape=(length,) + y.shape), y_abs
    )
    if length == 0:
        # Nothing to iterate over; the step cannot index a zero-length axis.
        return h, ys

    def step(state):
        carry, i, out = state
        x_i = tree_get(xs, i)
        carry, y_i = f(carry, x_i)
        out = tree_set(out, i, y_i)
        i = i - 1 if reverse else i + 1
        return carry, i, out

    start = length - 1 if reverse else 0
    h, _, ys = repeat(length, step, (h, start, ys))
    return h, ys
def parse_args():
    """Parse the command line for the scan comparison script.

    Returns:
      An ``argparse.Namespace`` with attributes ``steps``, ``scan``,
      ``show_lowered``, ``show_output`` and ``compile``.
    """
    parser = argparse.ArgumentParser(
        description="Compare different implementations of jax.lax.scan."
    )
    parser.add_argument(
        "--steps", type=int, default=10**5, help="number of steps to scan over"
    )
    parser.add_argument(
        "--scan",
        choices=["lax", "v1", "v2"],
        default="lax",
        help="which scan implementation to use",
    )
    # Integer-valued toggles: 0 (the default) disables, nonzero enables.
    toggles = (
        ("--show_lowered", "show lowered expression"),
        ("--show_output", "show output on the example inputs"),
        ("--compile", "compile the lowered expression"),
    )
    for flag, description in toggles:
        parser.add_argument(flag, type=int, default=0, help=description)
    return parser.parse_args()
def main():
    """Lower (and optionally run/compile) discounted returns over random inputs.

    Generates ``args.steps`` random rewards and discounts, then jits and lowers
    ``get_returns`` with ``unroll=True`` using the chosen scan implementation.
    It always prints the length of the lowered module text. Depending on the
    flags, it also prints the eager output, the lowered text, or the compiled
    object.
    """
    args = parse_args()
    reward_key, discount_key = random.split(random.key(0))
    rewards = random.normal(reward_key, [args.steps])
    discounts = random.uniform(discount_key, [args.steps])
    f = functools.partial(get_returns, unroll=True, scan=args.scan)
    if args.show_output:
        print(f(rewards, discounts))
    lowered = jax.jit(f).lower(rewards, discounts)
    # Rendering the (fully unrolled, potentially huge) module text is costly;
    # do it once and reuse it for both printing and measuring.
    text = lowered.as_text()
    if args.show_lowered:
        print(text)
    # Same label the previous ``f"{len(lowered.as_text())=}"`` produced.
    print(f"len(lowered.as_text())={len(text)}")
    if args.compile:
        compiled = lowered.compile()
        print(compiled)
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment