Skip to content

Instantly share code, notes, and snippets.

@carlosgmartin
Last active April 27, 2025 05:20
Show Gist options
  • Save carlosgmartin/b32fa6fed3aa82f83dfbaac4b6345672 to your computer and use it in GitHub Desktop.
Save carlosgmartin/b32fa6fed3aa82f83dfbaac4b6345672 to your computer and use it in GitHub Desktop.
Implementation of saturating arithmetic for JAX
"""
Saturating integer arithmetic for JAX.

Reference: https://github.com/jax-ml/jax/issues/26566
"""
import itertools
import operator
import jax
from jax import numpy as jnp
from tqdm import tqdm
def sat_add(a: jax.Array, b: jax.Array):
    """Add two integer arrays, clamping the sum to the dtype's range.

    Modelled on https://en.cppreference.com/w/cpp/numeric/add_sat

    Args:
        a: Left operand.
        b: Right operand; must have the same dtype as ``a``.

    Returns:
        ``a + b`` where representable, otherwise the dtype's maximum
        (sum too large) or minimum (sum too small).

    Raises:
        TypeError: If the dtypes of ``a`` and ``b`` differ.
    """
    if a.dtype != b.dtype:
        raise TypeError(f"{a.dtype=} and {b.dtype=} don't match.")
    limits = jnp.iinfo(a.dtype)
    hi = jnp.array(limits.max, a.dtype)
    lo = jnp.array(limits.min, a.dtype)

    # Each bound test is only meaningful under its sign guard: there the
    # subtraction ``hi - a`` / ``lo - a`` cannot itself wrap around.
    too_large = (a > 0) & (b > hi - a)
    too_small = (a < 0) & (b < lo - a)

    # The lower clamp is applied last, matching a sequential
    # "clamp high, then clamp low" update.
    return jnp.where(too_small, lo, jnp.where(too_large, hi, a + b))
def sat_sub(a: jax.Array, b: jax.Array):
    """Subtract ``b`` from ``a``, clamping the difference to the dtype's range.

    Modelled on https://en.cppreference.com/w/cpp/numeric/sub_sat

    Args:
        a: Minuend.
        b: Subtrahend; must have the same dtype as ``a``.

    Returns:
        ``a - b`` where representable, otherwise the dtype's maximum
        (difference too large) or minimum (difference too small).

    Raises:
        TypeError: If the dtypes of ``a`` and ``b`` differ.
    """
    if a.dtype != b.dtype:
        raise TypeError(f"{a.dtype=} and {b.dtype=} don't match.")
    limits = jnp.iinfo(a.dtype)
    hi = jnp.array(limits.max, a.dtype)
    lo = jnp.array(limits.min, a.dtype)

    # Subtracting a negative can exceed the maximum; subtracting a positive
    # can drop below the minimum. The sign guards keep ``hi + b`` and
    # ``lo + b`` from wrapping in the cases where they are consulted.
    too_large = (b < 0) & (a > hi + b)
    too_small = (b > 0) & (a < lo + b)

    return jnp.where(too_small, lo, jnp.where(too_large, hi, a - b))
def sat_mul(a: jax.Array, b: jax.Array):
    """Multiply two integer arrays, clamping the product to the dtype's range.

    Modelled on https://en.cppreference.com/w/cpp/numeric/mul_sat

    Out-of-range products are detected without a wider type by comparing
    ``a`` against the floor-quotient (and remainder) of the dtype's limits
    divided by ``b``.

    Args:
        a: Left operand.
        b: Right operand; must have the same dtype as ``a``.

    Returns:
        ``a * b`` where representable, otherwise the dtype's maximum or
        minimum.

    Raises:
        TypeError: If the dtypes of ``a`` and ``b`` differ.
    """
    if a.dtype != b.dtype:
        raise TypeError(f"{a.dtype=} and {b.dtype=} don't match.")
    limits = jnp.iinfo(a.dtype)
    hi = jnp.array(limits.max, a.dtype)
    lo = jnp.array(limits.min, a.dtype)

    a_pos, a_neg = a > 0, a < 0
    b_pos, b_neg = b > 0, b < 0

    # Product above the maximum: both operands share a sign.
    hi_quot, hi_rem = jnp.divmod(hi, b)
    both_pos_over = a_pos & b_pos & (a > hi_quot)
    # With a negative divisor the floor quotient rounds away from the true
    # ratio, so equality with the quotient still overflows when a remainder
    # is left over.
    both_neg_over = a_neg & b_neg & (
        (a < hi_quot) | ((a == hi_quot) & (hi_rem < 0))
    )
    too_large = both_pos_over | both_neg_over

    # Product below the minimum: operands have opposite signs.
    lo_quot, lo_rem = jnp.divmod(lo, b)
    neg_pos_under = a_neg & b_pos & (
        (a < lo_quot) | ((a == lo_quot) & (lo_rem > 0))
    )
    # Special case: dividing the minimum by -1 is not representable
    # (e.g. -128 / -1 = 128 > 127 for int8, which yields -128 instead), so
    # the quotient test is meaningless there; a * -1 cannot underflow for
    # positive a anyway.
    pos_neg_under = a_pos & b_neg & (a > lo_quot) & (b != -1)
    too_small = neg_pos_under | pos_neg_under

    return jnp.where(too_small, lo, jnp.where(too_large, hi, a * b))
def sat_div(a: jax.Array, b: jax.Array):
    """Floor-divide ``a`` by ``b``, saturating the single overflowing case.

    References:
        https://en.cppreference.com/w/cpp/numeric/div_sat
        https://www.gnu.org/software/c-intro-and-ref/manual/html_node/Division-and-Remainder.html
        "Integer division overflows in one specific case..."

    The quotient is computed with ``//`` (floor division). The only
    saturated case is a signed minimum divided by -1, which returns the
    dtype's maximum.

    Args:
        a: Dividend.
        b: Divisor; must have the same dtype as ``a``.

    Returns:
        ``a // b``, with ``min_int // -1`` replaced by ``max_int`` for
        signed dtypes.

    Raises:
        TypeError: If the dtypes of ``a`` and ``b`` differ.
    """
    if a.dtype != b.dtype:
        raise TypeError(f"{a.dtype=} and {b.dtype=} don't match.")
    dtype = a.dtype
    quotient = a // b

    # Only signed dtypes have a quotient that can fall outside the range.
    if not jnp.issubdtype(dtype, jnp.signedinteger):
        return quotient

    limits = jnp.iinfo(dtype)
    hi = jnp.array(limits.max, dtype)
    lo = jnp.array(limits.min, dtype)
    overflows = (a == lo) & (b == -1)
    return jnp.where(overflows, hi, quotient)
def sat_cast(a: jax.Array, dtype, debug=False):
    """Convert ``a`` to the integer ``dtype``, clamping out-of-range values.

    Modelled on https://en.cppreference.com/w/cpp/numeric/saturate_cast

    Values are first clipped to the intersection of the source and target
    ranges (so the bounds are representable in the source dtype) and then
    converted, so the result really has the requested dtype instead of
    silently keeping the dtype of ``a``.

    Args:
        a: Integer array to convert.
        dtype: Target integer dtype.
        debug: Unused; accepted so existing call sites keep working.

    Returns:
        An array of dtype ``dtype`` with the same shape as ``a``.
    """
    source = jnp.iinfo(a.dtype)
    target = jnp.iinfo(dtype)
    lower = max(source.min, target.min)
    upper = min(source.max, target.max)
    # After clipping every value is representable in ``dtype``, so the
    # conversion below cannot wrap around.
    return jnp.clip(a, lower, upper).astype(dtype)
def sat_neg(a: jax.Array):
    """Negate an integer array, clamping the result to the dtype's range.

    Args:
        a: Integer array to negate.

    Returns:
        An array with the same shape and dtype as ``a``. For unsigned
        dtypes every element is 0 (the clamped negation of a non-negative
        value). For signed dtypes the result is ``-a``, except that the
        minimum value maps to the maximum because its negation is not
        representable.
    """
    dtype = a.dtype
    if jnp.issubdtype(dtype, jnp.unsignedinteger):
        # Keep the input's shape rather than collapsing to a 0-d scalar.
        return jnp.zeros_like(a)
    limits = jnp.iinfo(dtype)
    max_int = jnp.array(limits.max, dtype)
    min_int = jnp.array(limits.min, dtype)
    # Past the unsigned guard, ``iinfo`` only accepts integer dtypes, so the
    # dtype is signed here and no further signedness test is needed.
    return jnp.where(a == min_int, max_int, -a)
def clamp(value, min_value, max_value):
    """Return ``value`` limited to ``[min_value, max_value]``.

    The lower bound is tested first, so a value below ``min_value`` yields
    ``min_value`` even if the bounds are inverted.
    """
    return (
        min_value
        if value < min_value
        else max_value
        if value > max_value
        else value
    )
def sat_spec(dtype, op, args):
    """Reference result: apply ``op`` exactly, then clamp into ``dtype``.

    Args:
        dtype: Integer dtype whose range bounds the result.
        op: Callable applied to the unpacked arguments.
        args: Operands; JAX arrays are converted with ``.item()`` first.

    Returns:
        A JAX array of ``dtype`` holding the clamped result.
    """
    # Work on Python numbers so the operation is carried out under
    # infinite precision instead of wrapping in a fixed-width dtype.
    operands = tuple(
        arg.item() if isinstance(arg, jax.Array) else arg for arg in args
    )
    exact = op(*operands)
    limits = jnp.iinfo(dtype)
    return jnp.array(clamp(exact, limits.min, limits.max), dtype)
def get_values(dtype):
    """Return every representable value of integer ``dtype``, ascending."""
    limits = jnp.iinfo(dtype)
    # The stop bound is exclusive, hence the +1 to include the maximum.
    return jnp.arange(limits.min, limits.max + 1, dtype=dtype)
def main():
    """Exhaustively compare the saturating ops against ``sat_spec``.

    Runs over ``int8`` and ``uint8``: first every cast between the two
    dtypes, then every operand combination for negation, addition,
    subtraction, multiplication and floor division. On the first mismatch
    the details are printed and the debugger is entered.
    """
    dtypes = [jnp.int8, jnp.uint8]

    # Casts: each value of each source dtype into each target dtype.
    for value_dtype, target_dtype in itertools.product(dtypes, repeat=2):
        print(f"{value_dtype=} {target_dtype=}")
        for value in tqdm(get_values(value_dtype)):
            expected = sat_spec(target_dtype, lambda x: x, [value])
            output = sat_cast(value, target_dtype)
            if output == expected:
                continue
            sat_cast(value, target_dtype, debug=True)
            print(f"{value}\n{target_dtype.__name__}\n{expected=}\n{output=}")
            breakpoint()

    # (exact Python operator, saturating implementation, operand count)
    tests = [
        (operator.neg, sat_neg, 1),
        (operator.add, sat_add, 2),
        (operator.sub, sat_sub, 2),
        (operator.mul, sat_mul, 2),
        (operator.floordiv, sat_div, 2),
    ]
    for op, raw_sat_op, arity in tests:
        sat_op = jax.jit(raw_sat_op)
        for dtype in dtypes:
            print(f"{op=} {dtype=}")
            # Materialised so the progress bar knows the total count.
            combos = list(itertools.product(get_values(dtype), repeat=arity))
            for args in tqdm(combos):
                try:
                    expected = sat_spec(dtype, op, args)
                except ZeroDivisionError:
                    # Division by zero has no reference result; skip it.
                    continue
                output = sat_op(*args)
                if output == expected:
                    continue
                print(f"{dtype=}\n{args=}\n{output=}\n{expected=}")
                breakpoint()
# Run the exhaustive checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment