Skip to content

Instantly share code, notes, and snippets.

@carlosgmartin
Last active December 8, 2024 20:26
Show Gist options
  • Save carlosgmartin/a3055c7605157a54d48d108226a48b97 to your computer and use it in GitHub Desktop.
Save carlosgmartin/a3055c7605157a54d48d108226a48b97 to your computer and use it in GitHub Desktop.
"""
Improving the lowering and compilation of unrolled lax.scan loops
https://github.com/jax-ml/jax/discussions/25336
"""
import argparse
import functools
import jax
from jax import lax, numpy as jnp, random
def tree_get(tree, index):
    """Return a pytree shaped like ``tree`` with every leaf indexed by ``index``."""

    def take(leaf):
        return leaf[index]

    return jax.tree.map(take, tree)
def tree_set(tree, index, item):
    """Return a copy of ``tree`` where each leaf has ``item``'s matching leaf
    written at position ``index`` (functional ``.at[...].set`` update)."""

    def put(leaf, value):
        return leaf.at[index].set(value)

    return jax.tree.map(put, tree, item)
def get_returns(rewards, discounts, unroll=1, scan="lax"):
def f(carry, reward_discount):
reward, discount = reward_discount
new_carry = reward + discount * carry
return new_carry, new_carry
xs = rewards, discounts
match scan:
case "lax":
_, returns = lax.scan(f, 0.0, xs, unroll=unroll, reverse=True)
case "v1":
_, returns = scan_v1(f, 0.0, xs, unroll=unroll, reverse=True)
case "v2":
_, returns = scan_v2(f, 0.0, xs, unroll=unroll, reverse=True)
case _:
raise NotImplementedError
return returns
def repeat_with_outputs(n, f, x, reverse=False):
    """Call ``f`` on a carry ``n`` times in a Python loop, concatenating outputs.

    Every iteration goes through the same ``jax.jit``-wrapped ``f``. Before
    the loop the initial carry is cast to the dtypes ``f`` returns, so the
    first call has the same signature as the later ones.

    Args:
        n: number of iterations (non-negative).
        f: function ``carry -> (new_carry, y)``. Each leaf of ``y`` must have
            at least one dimension, because outputs are joined with
            ``jnp.concatenate`` along axis 0.
        x: initial carry (any pytree).
        reverse: if true, outputs are concatenated in the reverse of the
            order they were produced.

    Returns:
        ``(final_carry, ys)``. When ``n == 0``, ``ys`` holds empty arrays of
        leading length 0 with the dtype and trailing shape of ``f``'s output.
    """
    f = jax.jit(f)
    # Prevent creation of near-duplicate StableHLO function due to slight
    # differences in signature for the first iteration.
    # https://openxla.org/stablehlo/spec#functions
    x_abs, y_abs = jax.eval_shape(f, x)
    x = jax.tree.map(
        lambda leaf, spec: jnp.asarray(leaf, spec.dtype), x, x_abs
    )
    if n == 0:
        # jnp.concatenate (and a tree.map with no trees) cannot handle an
        # empty list, so build the empty result from the abstract output.
        empty = jax.tree.map(
            lambda spec: jnp.zeros((0,) + spec.shape[1:], spec.dtype), y_abs
        )
        return x, empty
    ys = []
    for _ in range(n):
        x, y = f(x)
        ys.append(y)
    if reverse:
        ys.reverse()
    ys = jax.tree.map(lambda *leaves: jnp.concatenate(leaves), *ys)
    return x, ys
def scan_v1(f, h, xs, unroll=1, reverse=False):
    """Fully unrolled scan built from repeated calls to one jitted step.

    Each step reads ``xs[i]``, applies ``f``, and returns its output with a
    new leading axis of length 1. ``repeat_with_outputs`` then concatenates
    these into the stacked outputs.

    Args:
        f: step function ``(carry, x) -> (carry, y)``, as for ``lax.scan``.
        h: initial carry.
        xs: pytree of arrays scanned along their leading axis.
        unroll: must be ``True``; only full unrolling is implemented.
        reverse: visit ``xs`` from last to first, as in ``lax.scan``.

    Returns:
        ``(final_carry, ys)`` with ``ys`` stacked along a new leading axis in
        the same index order as ``xs``.

    Raises:
        NotImplementedError: if ``unroll`` is not ``True``.
    """
    # Explicit check instead of ``assert`` so it is not stripped by ``-O``.
    if unroll is not True:
        raise NotImplementedError(
            f"scan_v1 only supports unroll=True, got unroll={unroll!r}"
        )

    @jax.jit
    def step(carry):
        h, i = carry
        h, yi = f(h, tree_get(xs, i))
        # Leading axis of length 1 so outputs can be concatenated on axis 0.
        yi = jax.tree.map(lambda t: jnp.expand_dims(t, 0), yi)
        i = i - 1 if reverse else i + 1
        return (h, i), yi

    length = jax.tree.leaves(xs)[0].shape[0]
    start = length - 1 if reverse else 0
    (h, _), ys = repeat_with_outputs(length, step, (h, start), reverse)
    return h, ys
def repeat(n, f, x):
    """Apply ``f`` to ``x`` ``n`` times, i.e. ``f(f(...f(x)))``, via one jitted ``f``.

    The initial value is first cast to the dtypes ``f`` produces so that
    every call, including the first, has the same signature.
    """
    jitted = jax.jit(f)
    # Prevent creation of near-duplicate StableHLO function due to slight
    # differences in signature for the first iteration.
    # https://openxla.org/stablehlo/spec#functions
    out_spec = jax.eval_shape(jitted, x)
    carry = jax.tree.map(
        lambda leaf, spec: jnp.asarray(leaf, spec.dtype), x, out_spec
    )
    return functools.reduce(lambda acc, _: jitted(acc), range(n), carry)
def scan_v2(f, h, xs, unroll=1, reverse=False):
    """Fully unrolled scan that writes each output into a preallocated buffer.

    The output structure is inferred by abstractly evaluating ``f`` on
    ``xs[0]``. Step ``i`` then writes its output into row ``i`` of the
    buffer with ``tree_set``. The loop is driven by ``repeat``.

    Args:
        f: step function ``(carry, x) -> (carry, y)``, as for ``lax.scan``.
        h: initial carry.
        xs: pytree of arrays scanned along their leading axis.
        unroll: must be ``True``; only full unrolling is implemented.
        reverse: visit ``xs`` from last to first, as in ``lax.scan``.

    Returns:
        ``(final_carry, ys)`` where row ``i`` of ``ys`` is the output computed
        from ``xs[i]``.

    Raises:
        NotImplementedError: if ``unroll`` is not ``True``.
    """
    # Explicit check instead of ``assert`` so it is not stripped by ``-O``.
    if unroll is not True:
        raise NotImplementedError(
            f"scan_v2 only supports unroll=True, got unroll={unroll!r}"
        )
    length = jax.tree.leaves(xs)[0].shape[0]
    # Only the per-step output structure is needed to size the buffer.
    _, y_abs = jax.eval_shape(f, h, tree_get(xs, 0))
    ys = jax.tree.map(
        lambda y: jnp.empty_like(y, shape=(length,) + y.shape), y_abs
    )

    def step(carry):
        h, i, ys = carry
        h, yi = f(h, tree_get(xs, i))
        ys = tree_set(ys, i, yi)
        i = i - 1 if reverse else i + 1
        return h, i, ys

    start = length - 1 if reverse else 0
    h, _, ys = repeat(length, step, (h, start, ys))
    return h, ys
def parse_args():
    """Parse the benchmark's command-line options from ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Compare different implementations of jax.lax.scan."
    )
    parser.add_argument(
        "--steps", type=int, default=10**5, help="number of steps to scan over"
    )
    parser.add_argument(
        "--scan",
        default="lax",
        help="which scan implementation to use",
        choices=["lax", "v1", "v2"],
    )
    # Integer 0/1 switches (rather than store_true) so "--flag 1" is accepted.
    switches = (
        ("--show_lowered", "show lowered expression"),
        ("--show_output", "show output on the example inputs"),
        ("--compile", "compile the lowered expression"),
    )
    for flag, description in switches:
        parser.add_argument(flag, type=int, default=0, help=description)
    return parser.parse_args()
def main():
    """Lower (and optionally run/compile) the discounted-return computation.

    Draws ``--steps`` normal rewards and uniform discounts from a fixed key.
    It then lowers ``get_returns`` with ``unroll=True`` and the chosen
    ``--scan`` implementation, and always prints the length of the lowered
    module text. The eager output, the lowered text and the compiled object
    are printed only when their flags are set.
    """
    args = parse_args()
    key = random.key(0)
    keys = random.split(key)
    rewards = random.normal(keys[0], [args.steps])
    discounts = random.uniform(keys[1], [args.steps])
    f = functools.partial(get_returns, unroll=True, scan=args.scan)
    if args.show_output:
        print(f(rewards, discounts))
    # To inspect the jaxpr instead:
    # print(jax.make_jaxpr(f)(rewards, discounts))
    lowered = jax.jit(f).lower(rewards, discounts)
    # A fully unrolled scan can produce a large module; render the text once.
    lowered_text = lowered.as_text()
    if args.show_lowered:
        print(lowered_text)
    # Same output as the original f"{len(lowered.as_text())=}".
    print(f"len(lowered.as_text())={len(lowered_text)!r}")
    if args.compile:
        compiled = lowered.compile()
        print(compiled)
        # print(compiled.as_text())


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment