Skip to content

Instantly share code, notes, and snippets.

@carlosgmartin
carlosgmartin / jax_saturating_arithmetic.py
Last active April 27, 2025 05:20
Implementation of saturating arithmetic for JAX
"""
https://github.com/jax-ml/jax/issues/26566
"""
import itertools
import operator
import jax
from jax import numpy as jnp
from tqdm import tqdm
"""
Improving the lowering and compilation of unrolled lax.scan loops
https://github.com/jax-ml/jax/discussions/25336
"""
import argparse
import functools
import jax
from jax import lax, numpy as jnp, random
@carlosgmartin
carlosgmartin / hungarian_algorithm.py
Last active November 20, 2024 04:46
Comparison of JAX implementations of the Hungarian algorithm
from timeit import default_timer
import jax
import optax
import tqdm
from jax import lax, numpy as jnp, random
@jax.jit
def scenic_hungarian_algorithm(cost):
import argparse
import ale_py
import cv2
import gymnasium as gym
import minari
import minigrid
import numpy as np
import sb3_contrib
import shortuuid
@carlosgmartin
carlosgmartin / quasi_random.py
Last active September 19, 2024 08:31
JAX implementation of the Roberts low-discrepancy sequence
import argparse
import jax
from jax import lax, numpy as jnp, random
from matplotlib import pyplot as plt, rcParams
def newton_raphson(f, x, iters):
"""Use the Newton-Raphson method to find a root of the given function."""
@carlosgmartin
carlosgmartin / finite_differences.py
Created September 19, 2024 02:45
JAX implementation of finite differences (forward and central)
import jax
from jax import numpy as jnp
from jax.flatten_util import ravel_pytree
def centered_finite_differences(f, x, eps):
x, unravel = ravel_pytree(x)
def helper(index):
xp = x.at[index].add(eps)
@carlosgmartin
carlosgmartin / pseudo_gradient.py
Last active September 19, 2024 08:07
Implementation of a pseudo-gradient estimator based on JAX and Optax
from functools import partial
import jax
import optax
from jax import numpy as jnp, random
def pseudo_gradient(
f,
x,
@carlosgmartin
carlosgmartin / nfgs.py
Created June 12, 2020 21:42
Solving batches of two-player zero-sum normal-form games
import numpy as np
import gurobipy
import cvxopt
from multiprocessing import Pool
from timeit import default_timer as timer
# Globally set Gurobi's OutputFlag parameter to 0 (Gurobi's switch for solver
# log output), so the models created by the solver functions below do not
# print solve logs.
gurobipy.setParam('OutputFlag', 0)
def max_min_gurobi(game):
model = gurobipy.Model()
v = model.addMVar(1)
@carlosgmartin
carlosgmartin / mobius.py
Created July 13, 2019 00:08
Computes the Möbius transform (Möbius inversion) of sequences indexed by the positive integers (arithmetic functions)
import numpy as np
import sympy.ntheory
import time
# https://mathoverflow.net/a/227408/74578
def mobius_transform(sequence):
    """Return the Möbius inversion of a 1-indexed sequence.

    Treat ``sequence[k - 1]`` as ``a(k)`` for ``k = 1 .. len(sequence)``.
    The returned sequence ``b`` satisfies ``a(n) = sum(b(d) for d | n)``,
    equivalently ``b(n) = sum(mu(n // d) * a(d) for d | n)``.

    The input itself is left untouched: a ``.copy()`` is transformed and
    returned. The update uses augmented subtraction on strided slices, so
    the argument must support that (e.g. a NumPy array). A plain Python
    list will raise ``TypeError``.
    """
    result = sequence.copy()
    length = len(result)
    # Sieve over divisors in increasing order. When ``divisor`` is reached,
    # every proper divisor of it has already been subtracted out, so
    # result[divisor - 1] already equals b(divisor). Divisors larger than
    # length // 2 have no multiple inside the sequence and can be skipped.
    for divisor in range(1, length // 2 + 1):
        # Remove b(divisor) from the entries for 2*divisor, 3*divisor, ...
        result[2 * divisor - 1::divisor] -= result[divisor - 1]
    return result
@carlosgmartin
carlosgmartin / stackermann.py
Created June 20, 2019 07:00
Stack-based implementation of the Ackermann function
# https://www.sciencedirect.com/science/article/pii/0304397588900461
def stackermann(i, n):
stack = []
stack.append(i)
stack.append(n)
while len(stack) > 1:
n = stack.pop()
i = stack.pop()
if i == 0: