Last active
September 19, 2024 08:31
-
-
Save carlosgmartin/1fd4e60bed526ec8ae076137ded6ebab to your computer and use it in GitHub Desktop.
JAX implementation of the Roberts low-discrepancy sequence
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import jax | |
from jax import lax, numpy as jnp, random | |
from matplotlib import pyplot as plt, rcParams | |
def newton_raphson(f, x, iters):
    """Use the Newton-Raphson method to find a root of the given function.

    Runs a fixed number of Newton steps ``x <- x - f(x) / f'(x)`` inside
    ``lax.scan`` (no early stopping), so the loop is traceable by JAX.

    Args:
      f: Scalar function of a scalar, differentiable with ``jax.grad``.
      x: Initial guess. Integer guesses are promoted to a floating dtype,
        because ``jax.grad`` only differentiates floating-point inputs.
      iters: Number of Newton steps to perform.

    Returns:
      The estimate after ``iters`` steps (``x`` itself when ``iters == 0``).
    """
    df = jax.grad(f)

    def update(x, _):
        return x - f(x) / df(x), None

    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(float)
    # Start from the caller's guess (it used to be ignored in favour of 1.0).
    x, _ = lax.scan(update, x, length=iters)
    return x
def roberts_sequence(
    num_points,
    dim,
    root_iters=10_000,
    complement_basis=True,
    key=None,
    perturb=False,
    shuffle=False,
):
    """Returns the Roberts sequence, a low-discrepancy sequence:
    extremelearning.com.au/unreasonable-effectiveness-of-quasirandom-sequences

    Low-discrepancy sequences are useful for quasi-Monte Carlo methods.

    Point ``n`` is the coordinate-wise fractional part of ``n * basis``, where
    ``basis[i] = 1 / phi_d ** (i + 1)`` and ``phi_d`` is the unique positive
    root of ``x ** (dim + 1) - x - 1``.

    Args:
      num_points: Number of points to return.
      dim: The dimensionality of the sequence.
      root_iters: Number of Newton-Raphson iterations used to find ``phi_d``.
      complement_basis: Use ``1 - basis`` instead of ``basis`` for higher
        precision, as described in
        https://www.martysmods.com/a-better-r2-sequence.
      key: A PRNG key. Required when ``perturb`` or ``shuffle`` is True.
      perturb: Add one uniformly random offset per dimension to every point
        before the fractional part (modulo 1) is taken. The modulo itself is
        always applied, with or without this option.
      shuffle: Shuffle the order of the points before returning them.
        Warning: This degrades the low-discrepancy property for prefixes of
        the output sequence.

    Returns:
      An array of shape (num_points, dim) with entries in [0, 1).

    Raises:
      ValueError: If ``perturb`` or ``shuffle`` is requested without a key.
    """
    # Validate before doing any work.
    if perturb and key is None:
        raise ValueError("key cannot be None when perturb=True")
    if shuffle and key is None:
        raise ValueError("key cannot be None when shuffle=True")

    def f(x):
        return x ** (dim + 1) - x - 1

    # Compute the unique positive root of f using the Newton-Raphson method.
    # No Python asserts on `root`: it may be a traced value under jax.jit,
    # where a Python-level assert on it cannot be evaluated.
    root = newton_raphson(f, 1.0, root_iters)
    basis = 1 / root ** (1 + jnp.arange(dim))
    if complement_basis:
        basis = 1 - basis
    n = jnp.arange(num_points)
    x = n[:, None] * basis[None, :]
    if perturb:
        key, subkey = random.split(key)
        x += random.uniform(subkey, [dim])
    # Keep only the fractional part; x is non-negative here, so this is the
    # coordinate-wise modulo 1.
    x, _ = jnp.modf(x)
    if shuffle:
        # Permutes whole points (rows), not individual coordinates.
        x = random.permutation(key, x)
    return x
def cumulative_min(x):
    """Return the running minimum of ``x`` along its leading axis.

    ``y[i] = min(x[0], ..., x[i])``, computed with ``lax.cummin`` instead of a
    hand-written sequential ``lax.scan``. Works for arrays of any rank (the
    scan-based version only handled 1-D input). The result dtype is promoted
    exactly as ``jnp.minimum(jnp.inf, x)`` would be, so integer input yields a
    floating result as before.

    Args:
      x: Array whose leading axis is scanned.

    Returns:
      Array with the same shape as ``x``.
    """
    x = jnp.asarray(x)
    x = x.astype(jnp.result_type(x, jnp.inf))
    return lax.cummin(x, axis=0)
def min_distances(points):
    """Closest-pair distance within each prefix of ``points``.

    Entry ``k`` of the result is the smallest Euclidean distance between any
    two of the first ``k + 1`` points (``inf`` for ``k == 0``, where there is
    no pair yet).

    Points are processed one at a time with ``lax.map``, so only an
    ``(num_points, dim)`` difference array exists per step, instead of the full
    ``(num_points, num_points, dim)`` tensor. The arithmetic is still
    quadratic in ``num_points``.

    Args:
      points: 2-D array of shape (num_points, dim).

    Returns:
      Array of shape (num_points,), non-increasing.
    """
    points = jnp.asarray(points)
    num_points, _ = points.shape
    index = jnp.arange(num_points)

    def nearest_earlier(args):
        # Distance from point i to its nearest predecessor (any j < i).
        i, p = args
        d = jnp.linalg.norm(points - p, axis=-1)
        return d.min(where=index < i, initial=jnp.inf)

    dists = lax.map(nearest_earlier, (index, points))
    return cumulative_min(dists)
def parse_args(argv=None):
    """Parse command-line options for the demo.

    Args:
      argv: Argument list to parse. ``None`` (the default) makes argparse
        read ``sys.argv[1:]``.

    Returns:
      An ``argparse.Namespace`` with ``seed``, ``points`` and ``markersize``.
    """
    parser = argparse.ArgumentParser(
        description="Compare uniform random points with the Roberts sequence."
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="PRNG seed (default: %(default)s)"
    )
    parser.add_argument(
        "--points",
        type=int,
        default=30_000,
        help="number of points per point set (default: %(default)s)",
    )
    parser.add_argument(
        "--markersize",
        type=float,
        default=1,
        help="scatter-plot marker size (default: %(default)s)",
    )
    return parser.parse_args(argv)
def main():
    """Plot random vs. Roberts-sequence points and their closest-pair distances.

    Draws one scatter plot per point set. A second figure shows, on log-log
    axes, the closest-pair distance among the first ``n`` points against
    ``n``, with ``1/n`` and ``1/sqrt(n)`` reference lines.
    """
    args = parse_args()
    # Separate keys for the uniform sample and the perturbation, so no key is
    # both consumed directly and split.
    key_uniform, key_perturb = random.split(random.key(args.seed))
    points_labels = [
        (random.uniform(key_uniform, [args.points, 2]), "random"),
        (roberts_sequence(args.points, 2), "Roberts sequence"),
        (
            roberts_sequence(args.points, 2, key=key_perturb, perturb=True),
            "Roberts sequence (perturbed)",
        ),
    ]
    _, axs = plt.subplots(
        ncols=len(points_labels),
        figsize=plt.figaspect(1 / len(points_labels)),
        constrained_layout=True,
    )
    for ax in axs:
        ax.set(aspect="equal")
    _, ax_dist = plt.subplots(constrained_layout=True)
    # Entry k of min_distances(...) describes the first k + 1 points, so the
    # "number of points" axis starts at 1. This also keeps 1/n finite and
    # every sample visible on the log-scaled x-axis.
    n = jnp.arange(1, args.points + 1)
    for i, (ax, (points, label)) in enumerate(zip(axs, points_labels)):
        ax.set_title(label)
        color = f"C{i}"
        ax.scatter(
            *points.T,
            s=args.markersize,
            facecolor=color,
            edgecolor="none",
        )
        ax_dist.plot(n, min_distances(points), label=label, color=color)
    ax_dist.plot(
        n,
        1 / n,
        linestyle="--",
        label="$1/n$",
        color="black",
    )
    ax_dist.plot(
        n,
        1 / n**0.5,
        linestyle="--",
        label="$1/\\sqrt{n}$",
        color="crimson",
    )
    ax_dist.set(
        xscale="log",
        yscale="log",
        ylabel="distance between closest pair of points",
        xlabel="number of points ($n$)",
    )
    ax_dist.legend()
    rcParams["savefig.dpi"] = 300
    plt.show()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment