This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
https://github.com/jax-ml/jax/issues/26566 | |
""" | |
import itertools | |
import operator | |
import jax | |
from jax import numpy as jnp | |
from tqdm import tqdm |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Improving the lowering and compilation of unrolled lax.scan loops | |
https://github.com/jax-ml/jax/discussions/25336 | |
""" | |
import argparse | |
import functools | |
import jax | |
from jax import lax, numpy as jnp, random |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from timeit import default_timer | |
import jax | |
import optax | |
import tqdm | |
from jax import lax, numpy as jnp, random | |
@jax.jit | |
def scenic_hungarian_algorithm(cost): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import ale_py | |
import cv2 | |
import gymnasium as gym | |
import minari | |
import minigrid | |
import numpy as np | |
import sb3_contrib | |
import shortuuid |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import jax | |
from jax import lax, numpy as jnp, random | |
from matplotlib import pyplot as plt, rcParams | |
def newton_raphson(f, x, iters): | |
"""Use the Newton-Raphson method to find a root of the given function.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
from jax import numpy as jnp | |
from jax.flatten_util import ravel_pytree | |
def centered_finite_differences(f, x, eps): | |
x, unravel = ravel_pytree(x) | |
def helper(index): | |
xp = x.at[index].add(eps) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
import jax | |
import optax | |
from jax import numpy as jnp, random | |
def pseudo_gradient( | |
f, | |
x, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import gurobipy | |
import cvxopt | |
from multiprocessing import Pool | |
from timeit import default_timer as timer | |
gurobipy.setParam('OutputFlag', 0) | |
def max_min_gurobi(game): | |
model = gurobipy.Model() | |
v = model.addMVar(1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import sympy.ntheory | |
import time | |
# https://mathoverflow.net/a/227408/74578 | |
def mobius_transform(sequence):
    """Return the Möbius (inverse divisor-sum) transform of *sequence*.

    The element at 0-based index ``k`` is treated as the term at 1-based
    position ``k + 1``. For every position ``i`` the (already finalized)
    value at ``i`` is subtracted from the values at positions ``2i, 3i, ...``.
    If the input holds ``g(n) = sum of f(d) over all divisors d of n``, the
    result holds ``f(n)``.

    The input is not modified; the work is done on ``sequence.copy()``.
    The in-place strided subtraction relies on array-style slice arithmetic
    (the surrounding snippet imports NumPy), so a plain ``list`` with two or
    more elements is not supported.
    """
    result = sequence.copy()
    # Positions past len // 2 have no proper multiple inside the sequence,
    # so there is nothing left to subtract from.
    for position in range(1, len(result) // 2 + 1):
        index = position - 1  # 0-based index of the 1-based position
        # Start at position 2 * position (0-based: index + position) and
        # step by `position` to hit every proper multiple.
        result[index + position::position] -= result[index]
    return result
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://www.sciencedirect.com/science/article/pii/0304397588900461 | |
def stackermann(i, n): | |
stack = [] | |
stack.append(i) | |
stack.append(n) | |
while len(stack) > 1: | |
n = stack.pop() | |
i = stack.pop() | |
if i == 0: |
NewerOlder