Skip to content

Instantly share code, notes, and snippets.

@carlosgmartin
Last active November 20, 2024 04:46
Show Gist options
  • Save carlosgmartin/7a61a80f6340280734a381300534c4d3 to your computer and use it in GitHub Desktop.
Comparison of JAX implementations of the Hungarian algorithm
from timeit import default_timer
import jax
import optax
import tqdm
from jax import lax, numpy as jnp, random
@jax.jit
def scenic_hungarian_algorithm(cost):
    """Solve a single linear assignment problem (Hungarian matcher, one example).

    Uses row potentials ``u`` and column potentials ``v`` and inserts one row
    at a time, growing an augmenting path per row. Rows and columns are
    1-based internally; column 0 is a virtual column whose ``parent`` entry
    holds the row currently being inserted. Everything is expressed with
    ``lax.scan`` / ``lax.while_loop`` so the function can be jitted.

    Args:
        cost: 2-D array of shape ``(rows, cols)``. Rectangular inputs are
            allowed; when ``rows > cols`` the transpose is solved internally.

    Returns:
        ``(row_indices, col_indices)``: two int32 arrays of length
        ``min(rows, cols)``, paired element-wise, such that
        ``cost[row_indices, col_indices].sum()`` is minimal. For square inputs
        ``col_indices`` is ``arange(rows)``.
    """
    # Compute in at least float32 (float64 stays float64 when x64 is enabled)
    # so every loop carry below keeps one consistent dtype.
    float_dtype = jnp.promote_types(cost.dtype, jnp.float32)
    cost = jnp.asarray(cost, dtype=float_dtype)
    # Stand-in for +infinity; assumes reduced costs stay far below 1e10.
    sentinel = 1e10

    # The path-growing loop needs rows <= cols; solve the transpose otherwise.
    is_transpose = cost.shape[0] > cost.shape[1]
    if is_transpose:
        cost = cost.T
    n, m = cost.shape
    one_hot_m = jnp.eye(m + 1)

    def row_scan_fn(state, i):
        """Insert row ``i`` (1-based, 1..n) into the current matching."""
        u, v, parent = state
        # parent[0] = i: the virtual column 0 starts the path at row i.
        parent = jax.lax.dynamic_update_index_in_dim(parent, i, 0, axis=0)

        def dfs_body_fn(state):
            # Row potentials, column potentials, visited-column mask, minimum
            # slack per column, predecessor column per column, current column.
            u, v, used, minv, way, j0 = state
            # Mark column j0 as visited.
            used = jnp.logical_or(used, one_hot_m[j0])
            used_slice = used[1:]
            # Row matched to column j0 (for j0 == 0, the row being inserted).
            i0 = parent[j0]
            # Reduced costs from row i0 to every column; visited ones masked.
            cur = cost[i0 - 1, :] - u[i0] - v[1:]
            cur = jnp.where(used_slice, jnp.full_like(cur, sentinel), cur)
            way = jnp.where(cur < minv, jnp.full_like(way, j0), way)
            minv = jnp.where(cur < minv, cur, minv)  # type: ignore
            # Choose the unvisited column with the smallest slack.
            masked_minv = jnp.where(used_slice, jnp.full_like(minv, sentinel), minv)
            # Pin to int32 so the carry dtype matches under x64 as well.
            j1 = (jnp.argmin(masked_minv) + 1).astype(jnp.int32)  # type: ignore
            delta = jnp.min(minv, initial=sentinel, where=jnp.logical_not(used_slice))
            # Update potentials. Unvisited columns scatter into u[n + 1], a
            # spare slot (u has n + 2 entries) that is never read as a row.
            indices = jnp.where(used, parent, n + 1)
            u = u.at[indices].add(delta)
            v = jnp.where(used, v - delta, v)
            minv = jnp.where(jnp.logical_not(used_slice), minv - delta, minv)
            return (u, v, used, minv, way, j1)

        def dfs_cond_fn(state):
            # Keep extending the path until it reaches a free column (parent 0).
            _, _, _, _, _, j0 = state
            return parent[j0] != 0

        # Grow the augmenting path starting from the virtual column 0.
        way = jnp.zeros((m,), dtype=jnp.int32)
        used = jnp.zeros((m + 1,), dtype=jnp.bool_)
        minv = jnp.full((m,), sentinel, dtype=float_dtype)
        init_state = (u, v, used, minv, way, jnp.int32(0))
        u, v, _, _, way, j0 = jax.lax.while_loop(dfs_cond_fn, dfs_body_fn, init_state)

        def update_parent_body_fn(state):
            """Shift matches one step back along the augmenting path."""
            parent, j0 = state
            j1 = way[j0 - 1]
            parent = jax.lax.dynamic_update_index_in_dim(parent, parent[j1], j0, axis=0)
            return (parent, j1)

        def update_parent_cond_fn(state):
            """Stop once the walk is back at the virtual column 0."""
            _, j0 = state
            return j0 != 0

        # Backtrack the path, flipping matched/unmatched edges.
        parent, _ = jax.lax.while_loop(
            update_parent_cond_fn, update_parent_body_fn, (parent, j0)
        )
        return (u, v, parent), None

    u = jnp.zeros((n + 2,), dtype=float_dtype)
    v = jnp.zeros((m + 1,), dtype=float_dtype)
    parent = jnp.zeros((m + 1,), dtype=jnp.int32)
    (u, v, parent), _ = jax.lax.scan(
        row_scan_fn, (u, v, parent), jnp.arange(1, n + 1, dtype=jnp.int32)
    )

    # parent[j] is the 1-based row matched to column j (0 = unmatched).
    # -v[0] is the matching cost, but it is not returned, to keep the same
    # signature as the other matchers.
    if n != m:
        # Unmatched columns hold 0, so top_k selects exactly the n matched
        # ones. It is costly, so it is skipped for square matrices.
        rows, cols = jax.lax.top_k(parent[1:], n)
    else:
        rows, cols = parent[1:], jnp.arange(n, dtype=jnp.int32)
    rows = rows - 1  # back to 0-based row indices
    if is_transpose:
        return cols, rows
    return rows, cols
def _time_compiled(label, fn, batch):
    """Compile ``fn`` for ``batch`` ahead of time, then print one timed run.

    Compiling before starting the timer keeps tracing and XLA compilation
    out of the measurement, so the printed number is execution time only.
    """
    compiled = jax.jit(fn).lower(batch).compile()
    start = default_timer()
    jax.block_until_ready(compiled(batch))
    end = default_timer()
    print(f"{label}: {end - start:g}")


def main(rows=20, cols=21, comparisons=1000, samples=10**5):
    """Check that the optax and scenic solvers agree, then benchmark both.

    Args:
        rows: number of rows of each cost matrix.
        cols: number of columns of each cost matrix.
        comparisons: number of random matrices on which the total assignment
            costs of the two implementations are compared.
        samples: batch size used for the lax.map and vmap timings.
    """
    optax_hungarian_algorithm = jax.jit(optax.assignment.hungarian_algorithm)
    key = random.key(0)

    # Print both traced programs for a side-by-side structural comparison.
    cost_matrix = jnp.empty((rows, cols))
    optax_jaxpr = jax.make_jaxpr(optax_hungarian_algorithm)(cost_matrix)
    scenic_jaxpr = jax.make_jaxpr(scenic_hungarian_algorithm)(cost_matrix)
    print(f"jaxpr for optax version:\n{optax_jaxpr}")
    print(f"jaxpr for scenic version:\n{scenic_jaxpr}")

    # Only the total costs are compared, not the assignments themselves.
    for _ in tqdm.trange(comparisons):
        key, subkey = random.split(key)
        cost_matrix = random.normal(subkey, (rows, cols))
        optax_cost = cost_matrix[optax_hungarian_algorithm(cost_matrix)].sum()
        scenic_cost = cost_matrix[scenic_hungarian_algorithm(cost_matrix)].sum()
        # Explicit raise: unlike `assert`, this check survives `python -O`.
        if not jnp.isclose(optax_cost, scenic_cost):
            raise AssertionError(
                f"cost mismatch: optax={optax_cost}, scenic={scenic_cost}"
            )

    cost_matrices = random.normal(key, (samples, rows, cols))
    print(f"running on a batch of cost matrices of size {cost_matrices.shape}")
    _time_compiled(
        "lax.map runtime for optax version",
        lambda x: lax.map(optax_hungarian_algorithm, x),
        cost_matrices,
    )
    _time_compiled(
        "lax.map runtime for scenic version",
        lambda x: lax.map(scenic_hungarian_algorithm, x),
        cost_matrices,
    )
    _time_compiled(
        "vmap runtime for optax version",
        jax.vmap(optax_hungarian_algorithm),
        cost_matrices,
    )
    _time_compiled(
        "vmap runtime for scenic version",
        jax.vmap(scenic_hungarian_algorithm),
        cost_matrices,
    )
if __name__ == "__main__":
    # Run the correctness comparison and benchmarks only when executed as a
    # script, not when this module is imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment