Last active
April 27, 2025 05:20
-
-
Save carlosgmartin/b32fa6fed3aa82f83dfbaac4b6345672 to your computer and use it in GitHub Desktop.
Implementation of saturating arithmetic for JAX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
https://github.com/jax-ml/jax/issues/26566 | |
""" | |
import itertools | |
import operator | |
import jax | |
from jax import numpy as jnp | |
from tqdm import tqdm | |
def sat_add(a: jax.Array, b: jax.Array):
    """Add two same-dtype integer arrays, clamping to the dtype's range instead of wrapping.

    Intended to mirror C++26 ``std::add_sat``:
    https://en.cppreference.com/w/cpp/numeric/add_sat

    Raises:
        TypeError: if ``a.dtype`` and ``b.dtype`` differ.
    """
    if a.dtype != b.dtype:
        raise TypeError(f"{a.dtype=} and {b.dtype=} don't match.")
    info = jnp.iinfo(a.dtype)
    hi = jnp.array(info.max, a.dtype)
    lo = jnp.array(info.min, a.dtype)
    wrapped = a + b
    # Where a > 0, the exact sum exceeds hi iff b > hi - a (hi - a cannot wrap there).
    too_big = (a > 0) & (b > hi - a)
    # Where a < 0, the exact sum is below lo iff b < lo - a (lo - a cannot wrap there).
    # Lanes outside each mask may compute wrapped bounds, but they are never selected.
    too_small = (a < 0) & (b < lo - a)
    return jnp.where(too_small, lo, jnp.where(too_big, hi, wrapped))
def sat_sub(a: jax.Array, b: jax.Array):
    """Subtract ``b`` from ``a`` (same integer dtype), saturating at the dtype's bounds.

    Intended to mirror C++26 ``std::sub_sat``:
    https://en.cppreference.com/w/cpp/numeric/sub_sat

    Raises:
        TypeError: if ``a.dtype`` and ``b.dtype`` differ.
    """
    if a.dtype != b.dtype:
        raise TypeError(f"{a.dtype=} and {b.dtype=} don't match.")
    info = jnp.iinfo(a.dtype)
    hi, lo = (jnp.array(bound, a.dtype) for bound in (info.max, info.min))
    diff = a - b
    # Subtracting a negative b raises the value: it passes hi iff a > hi + b.
    rises_past_hi = (b < 0) & (a > hi + b)
    # Subtracting a positive b lowers the value: it passes lo iff a < lo + b.
    falls_past_lo = (b > 0) & (a < lo + b)
    return jnp.where(falls_past_lo, lo, jnp.where(rises_past_hi, hi, diff))
def sat_mul(a: jax.Array, b: jax.Array):
    """Multiply two same-dtype integer arrays, saturating at the dtype's bounds.

    Intended to mirror C++26 ``std::mul_sat``:
    https://en.cppreference.com/w/cpp/numeric/mul_sat

    The code does not widen to a larger dtype. It computes the wrapped product
    ``a * b`` and overwrites it with ``max_int`` / ``min_int`` wherever a
    division-based comparison shows that the exact product falls outside the
    dtype's range.

    Raises:
        TypeError: if ``a.dtype`` and ``b.dtype`` differ.
    """
    if a.dtype != b.dtype:
        raise TypeError(f"{a.dtype=} and {b.dtype=} don't match.")
    dtype = a.dtype
    max_int = jnp.array(jnp.iinfo(dtype).max, dtype)
    min_int = jnp.array(jnp.iinfo(dtype).min, dtype)
    # Wrapped product; overwritten below wherever it over/underflows.
    c = a * b
    # Floor quotient and remainder of max_int by b. Lanes with b == 0 give
    # meaningless values, but every mask below requires b > 0 or b < 0.
    q, r = jnp.divmod(max_int, b)
    # Both positive: a * b > max_int  <=>  a > floor(max_int / b).
    overflow = (a > 0) & (b > 0) & (a > q)
    # Both negative: a * b > max_int  <=>  a < max_int / b. Floor division gives
    # q <= max_int / b, with equality only when the remainder r is 0 (r <= 0 here).
    overflow |= (a < 0) & (b < 0) & ((a < q) | ((a == q) & (r < 0)))
    c = jnp.where(overflow, max_int, c)
    q, r = jnp.divmod(min_int, b)
    # a negative, b positive: a * b < min_int  <=>  a < min_int / b; r > 0
    # marks an inexact quotient, so a == q still underflows.
    underflow = (a < 0) & (b > 0) & ((a < q) | ((a == q) & (r > 0)))
    # a positive, b negative: a * b < min_int  <=>  a > q. b == -1 is excluded
    # because min_int // -1 itself overflows: -128 / (-1) = 128 > 127 = max int
    # for int8, which yields -128 instead. That would flag every positive a,
    # while a * -1 can never underflow.
    underflow |= (a > 0) & (b < 0) & (a > q) & (b != -1)
    c = jnp.where(underflow, min_int, c)
    return c
def sat_div(a: jax.Array, b: jax.Array):
    """Floor-divide ``a`` by ``b`` (same integer dtype), saturating the one overflow case.

    References:
    https://en.cppreference.com/w/cpp/numeric/div_sat
    https://www.gnu.org/software/c-intro-and-ref/manual/html_node/Division-and-Remainder.html
    "Integer division overflows in one specific case..."

    NOTE: this uses ``//`` (floor division, rounding toward negative infinity),
    whereas C++ ``std::div_sat`` truncates toward zero. Results differ when the
    operands have opposite signs and the division is inexact.

    For signed dtypes, ``min_int // -1`` would be ``-min_int``, which is not
    representable, so that lane returns ``max_int``. Division by zero is not
    special-cased here.

    Raises:
        TypeError: if ``a.dtype`` and ``b.dtype`` differ.
    """
    if a.dtype != b.dtype:
        raise TypeError(f"{a.dtype=} and {b.dtype=} don't match.")
    # Looked up before anything else so non-integer dtypes are rejected by iinfo.
    bounds = jnp.iinfo(a.dtype)
    quotient = a // b
    if not jnp.issubdtype(a.dtype, jnp.signedinteger):
        # Unsigned division can never overflow.
        return quotient
    overflows = (a == bounds.min) & (b == -1)
    return jnp.where(overflows, jnp.array(bounds.max, a.dtype), quotient)
def sat_cast(a: jax.Array, dtype, debug=False):
    """Convert integer array ``a`` to integer ``dtype``, clamping out-of-range values.

    Intended to mirror C++26 ``std::saturate_cast``:
    https://en.cppreference.com/w/cpp/numeric/saturate_cast

    Values below the target's minimum become that minimum, values above its
    maximum become that maximum, and all other values convert exactly.

    Args:
        a: integer array to convert.
        dtype: target integer dtype.
        debug: unused; kept so existing callers passing it keep working.

    Returns:
        An array with ``a``'s shape and dtype ``dtype``. The previous version
        only clipped and returned ``a.dtype``.
    """
    del debug  # accepted for backward compatibility; has no effect
    src, dst = jnp.iinfo(a.dtype), jnp.iinfo(dtype)
    # Clip in the source dtype using the intersection of both ranges. These
    # bounds are representable in a.dtype, and afterward every value is
    # representable in dtype, so the final astype cannot wrap.
    clipped = a.clip(min=max(src.min, dst.min), max=min(src.max, dst.max))
    return clipped.astype(dtype)
def sat_neg(a: jax.Array):
    """Negate integer array ``a``, saturating instead of wrapping.

    Signed dtypes: ``-min_int`` is not representable, so ``min_int`` maps to
    ``max_int``; every other value negates exactly.
    Unsigned dtypes: the exact negation of any non-zero value is negative, and
    the negation of 0 is 0, so every element saturates to 0.

    Returns:
        An array with the same shape and dtype as ``a``. The previous version
        returned a 0-d scalar for unsigned input.
    """
    dtype = a.dtype
    if jnp.issubdtype(dtype, jnp.unsignedinteger):
        # A full zero array (not a scalar) so the output keeps a's shape.
        return jnp.zeros_like(a)
    max_int = jnp.array(jnp.iinfo(dtype).max, dtype)
    min_int = jnp.array(jnp.iinfo(dtype).min, dtype)
    # iinfo above only accepts integer dtypes, and unsigned returned early,
    # so the dtype is signed here.
    return jnp.where(a == min_int, max_int, -a)
def clamp(value, min_value, max_value):
    """Limit ``value`` to ``[min_value, max_value]`` using plain comparisons.

    The lower bound is tested first, so if ``min_value > max_value`` any
    value below ``min_value`` yields ``min_value``.
    """
    return (
        min_value if value < min_value
        else max_value if value > max_value
        else value
    )
def sat_spec(dtype, op, args):
    """Reference oracle: apply ``op`` with exact Python ints, then saturate to ``dtype``.

    JAX array arguments are converted to Python scalars with ``.item()`` so
    that ``op`` runs without wrapping. The exact result is clamped to the
    range of ``dtype`` and returned as a JAX array of that dtype. Exceptions
    raised by ``op`` (e.g. ZeroDivisionError) propagate to the caller.
    """
    def as_python(x):
        return x.item() if isinstance(x, jax.Array) else x

    exact = op(*map(as_python, args))  # carry out operation under infinite precision
    bounds = jnp.iinfo(dtype)
    return jnp.array(clamp(exact, bounds.min, bounds.max), dtype)
def get_values(dtype):
    """Return every representable value of integer ``dtype``, in ascending order."""
    bounds = jnp.iinfo(dtype)
    return jnp.arange(bounds.min, bounds.max + 1, dtype=dtype)
def _check_casts(dtypes):
    """Compare sat_cast against sat_spec for every value of every (source, target) dtype pair."""
    for value_dtype, target_dtype in itertools.product(dtypes, repeat=2):
        print(f"{value_dtype=} {target_dtype=}")
        for value in tqdm(get_values(value_dtype)):
            expected = sat_spec(target_dtype, lambda x: x, [value])
            output = sat_cast(value, target_dtype)
            if output == expected:
                continue
            sat_cast(value, target_dtype, debug=True)
            print(f"{value}\n{target_dtype.__name__}\n{expected=}\n{output=}")
            breakpoint()


def _check_ops(dtypes):
    """Compare each jitted saturating op against sat_spec over all argument combinations."""
    cases = [
        (operator.neg, sat_neg, 1),
        (operator.add, sat_add, 2),
        (operator.sub, sat_sub, 2),
        (operator.mul, sat_mul, 2),
        (operator.floordiv, sat_div, 2),
    ]
    for op, sat_fn, arity in cases:
        jitted = jax.jit(sat_fn)
        for dtype in dtypes:
            print(f"{op=} {dtype=}")
            combos = list(itertools.product(get_values(dtype), repeat=arity))
            for args in tqdm(combos):
                try:
                    expected = sat_spec(dtype, op, args)
                except ZeroDivisionError:
                    # Undefined in Python; the saturating ops make no promise here.
                    continue
                output = jitted(*args)
                if output != expected:
                    print(f"{dtype=}\n{args=}\n{output=}\n{expected=}")
                    breakpoint()


def main():
    """Run the exhaustive int8/uint8 checks, entering the debugger on every mismatch."""
    dtypes = [jnp.int8, jnp.uint8]
    _check_casts(dtypes)
    _check_ops(dtypes)
if __name__ == "__main__":
    # Run the exhaustive checks only when executed as a script, not on import.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment