Skip to content

Instantly share code, notes, and snippets.

@carlosgmartin
Last active September 19, 2024 08:31
Show Gist options
  • Save carlosgmartin/1fd4e60bed526ec8ae076137ded6ebab to your computer and use it in GitHub Desktop.
Save carlosgmartin/1fd4e60bed526ec8ae076137ded6ebab to your computer and use it in GitHub Desktop.
JAX implementation of the Roberts low-discrepancy sequence
import argparse
import jax
from jax import lax, numpy as jnp, random
from matplotlib import pyplot as plt, rcParams
def newton_raphson(f, x, iters):
    """Use the Newton-Raphson method to find a root of the given function.

    Args:
        f: Scalar function of one scalar argument, differentiable with
            ``jax.grad``.
        x: Initial guess for the root. Integer guesses are promoted to
            floating point.
        iters: Number of Newton steps to take. This is a fixed count; there
            is no convergence test.

    Returns:
        The iterate after ``iters`` steps of ``x <- x - f(x) / f'(x)``.
    """
    # jax.grad rejects integer inputs, so promote the guess to a float dtype.
    x = jnp.asarray(x, dtype=jnp.result_type(x, 0.0))
    df = jax.grad(f)

    def update(x, _):
        return x - f(x) / df(x), None

    # Iterate from the caller's guess (this used to be hard-coded to 1.0,
    # silently ignoring the `x` argument).
    x, _ = lax.scan(update, x, length=iters)
    return x
def roberts_sequence(
    num_points,
    dim,
    root_iters=10_000,
    complement_basis=True,
    key=None,
    perturb=False,
    shuffle=False,
):
    """Returns the Roberts sequence, a low-discrepancy sequence:
    extremelearning.com.au/unreasonable-effectiveness-of-quasirandom-sequences

    Low-discrepancy sequences are useful for quasi-Monte Carlo methods.

    Point ``n`` (for ``n = 0, ..., num_points - 1``) is the fractional part of
    ``n * basis``, where ``basis[k] = 1 / phi ** (k + 1)`` and ``phi`` is the
    unique positive root of ``x ** (dim + 1) - x - 1``. Every coordinate is
    always reduced modulo 1, so all points lie in ``[0, 1)``.

    Args:
        num_points: Number of points to return.
        dim: The dimensionality of the sequence.
        root_iters: Number of Newton iterations used to find ``phi``.
        complement_basis: Use ``1 - basis`` instead of ``basis`` for higher
            precision, as described in
            https://www.martysmods.com/a-better-r2-sequence.
        key: A PRNG key. Required when ``perturb`` or ``shuffle`` is True.
        perturb: Add one uniformly random offset per dimension to the entire
            sequence before the modulo-1 reduction (a Cranley-Patterson
            rotation).
        shuffle: Shuffle the points (rows) of the sequence before returning
            them.
            Warning: This degrades the low-discrepancy property for prefixes of
            the output sequence.

    Returns:
        An array of shape (num_points, dim) containing the sequence.

    Raises:
        ValueError: If ``perturb`` or ``shuffle`` is requested without a key.
    """
    # Validate up front so a missing key fails before any work is done.
    if perturb and key is None:
        raise ValueError("key cannot be None when perturb=True")
    if shuffle and key is None:
        raise ValueError("key cannot be None when shuffle=True")

    def f(x):
        return x ** (dim + 1) - x - 1

    # phi: the unique positive root of f (the golden ratio for dim=1, the
    # plastic number for dim=2), found with Newton's method started at 1.0.
    root = newton_raphson(f, 1.0, root_iters)
    basis = 1 / root ** (1 + jnp.arange(dim))
    if complement_basis:
        # frac(n * (1 - a)) == 1 - frac(n * a): a mirrored copy of the same
        # sequence, used for the precision reason given in the linked article.
        basis = 1 - basis

    n = jnp.arange(num_points)
    x = n[:, None] * basis[None, :]
    if perturb:
        key, subkey = random.split(key)
        x += random.uniform(subkey, [dim])
    # Keep only the fractional part, i.e. reduce every coordinate modulo 1.
    x, _ = jnp.modf(x)
    if shuffle:
        # Permutes rows (whole points), not individual coordinates.
        x = random.permutation(key, x)
    return x
def cumulative_min(x, axis=0):
    """Return the running minimum of ``x`` along ``axis``.

    Along the chosen axis, ``y[i] = min(x[0], ..., x[i])``. Arrays of any rank
    are supported.

    Args:
        x: Array to accumulate over.
        axis: Axis along which to take the running minimum (default: the
            leading axis). Negative values count from the last axis.

    Returns:
        Array with the same shape and dtype as ``x``.
    """
    x = jnp.asarray(x)
    if x.size == 0:
        # Nothing to accumulate; avoid asking lax.cummin to handle an empty axis.
        return x
    if axis < 0:
        # lax.cummin only accepts non-negative axis indices.
        axis += x.ndim
    return lax.cummin(x, axis=axis)
def min_distances(points):
    """Closest-pair distance for every prefix of a point set.

    Args:
        points: Array of shape ``(num_points, dim)``.

    Returns:
        Array of shape ``(num_points,)`` whose entry ``i`` is the smallest
        Euclidean distance between any two of ``points[:i + 1]``. Entry 0 is
        ``inf``, because a single point has no pair yet.
    """
    num_points, _ = points.shape
    indices = jnp.arange(num_points)

    def nearest_earlier(i):
        # Distance from point i to every point, keeping only earlier ones
        # (j < i). Working one row at a time needs O(num_points) memory instead
        # of materialising the (num_points, num_points, dim) difference tensor,
        # which is several GB at the script's default of 30,000 points.
        dists = jnp.linalg.norm(points - points[i], axis=-1)
        return dists.min(where=indices < i, initial=jnp.inf)

    nearest = lax.map(nearest_earlier, indices)
    return cumulative_min(nearest)
def parse_args():
    """Parse the demo's command-line options.

    Returns:
        An ``argparse.Namespace`` with attributes ``seed`` (int, default 0),
        ``points`` (int, default 30_000) and ``markersize`` (float when given
        on the command line, default 1).
    """
    parser = argparse.ArgumentParser()
    option_table = (
        ("--seed", int, 0),
        ("--points", int, 30_000),
        ("--markersize", float, 1),
    )
    for flag, converter, default_value in option_table:
        parser.add_argument(flag, type=converter, default=default_value)
    return parser.parse_args()
def main():
    """Compare random points with Roberts sequences, visually and numerically.

    Opens two figures:
    - a scatter plot of each 2-D point set;
    - a log-log plot of the closest-pair distance among the first ``n`` points
      of each set, with ``1/n`` and ``1/sqrt(n)`` reference curves.
    """
    args = parse_args()
    # Independent keys for the two random draws; reusing one JAX key for both
    # would make the uniform points and the perturbation correlated.
    key_uniform, key_perturb = random.split(random.key(args.seed))
    points_labels = [
        (random.uniform(key_uniform, [args.points, 2]), "random"),
        (roberts_sequence(args.points, 2), "Roberts sequence"),
        (
            roberts_sequence(args.points, 2, key=key_perturb, perturb=True),
            "Roberts sequence (perturbed)",
        ),
    ]

    _, axs = plt.subplots(
        ncols=len(points_labels),
        figsize=plt.figaspect(1 / len(points_labels)),
        constrained_layout=True,
    )
    for ax in axs:
        ax.set(aspect="equal")
    _, ax_dist = plt.subplots(constrained_layout=True)

    # Entry i of min_distances covers the first i + 1 points, so plot against
    # n = 1..N. Starting at 1 also keeps the reference curves free of 1/0.
    n = jnp.arange(1, args.points + 1)
    for i, (ax, (points, label)) in enumerate(zip(axs, points_labels)):
        ax.set_title(label)
        color = f"C{i}"
        ax.scatter(
            *points.T,
            s=args.markersize,
            facecolor=color,
            edgecolor="none",
        )
        ax_dist.plot(n, min_distances(points), label=label, color=color)

    ax_dist.plot(
        n,
        1 / n,
        linestyle="--",
        label="$1/n$",
        color="black",
    )
    ax_dist.plot(
        n,
        1 / n**0.5,
        linestyle="--",
        label="$1/\\sqrt{n}$",
        color="crimson",
    )
    ax_dist.set(
        xscale="log",
        yscale="log",
        ylabel="distance between closest pair of points",
        xlabel="number of points ($n$)",
    )
    ax_dist.legend()
    rcParams["savefig.dpi"] = 300
    plt.show()
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment