Skip to content

Instantly share code, notes, and snippets.

@brianwisti
Created May 1, 2025 18:07
Show Gist options
  • Save brianwisti/6d399fdcc3ef6dccb5f958ef950049ed to your computer and use it in GitHub Desktop.
Save brianwisti/6d399fdcc3ef6dccb5f958ef950049ed to your computer and use it in GitHub Desktop.
Prime sieve comparison in plain Python, Numba, and Taichi
"""
Count the prime numbers below n, i.e. in the range [2, n).
Compares the same trial-division code in pure Python, Numba, and Taichi. This
is the sort of thing that happens when I have coffee and time on my hands
while looking at the `Taichi docs`_.
.. _Taichi docs: https://docs.taichi-lang.org/docs/accelerate_python
"""
import logging
import math
import timeit
from collections.abc import Callable

import taichi as ti
from numba import njit
from rich.console import Console
from rich.logging import RichHandler
from rich.table import Table
# Uncomment for prettier screenshots.
# import warnings
# warnings.simplefilter("ignore")
# Upper bound handed to every approach, and how many timed calls each gets.
N = 1_000_000
ITERATIONS = 100

# Every approach is called as ``approach(n)`` and returns the prime count,
# so the callable takes a single ``int`` (not two).
SieveFunction = Callable[[int], int]
# One report row: (approach name, formatted result, formatted total time,
# formatted mean time) -- all pre-formatted strings for the table.
ApproachResult = tuple[str, str, str, str]

# INFO level hides the logger.debug() diagnostics emitted while timing;
# lower it to logging.DEBUG to see per-approach details.
logging.basicConfig(level=logging.INFO, handlers=[RichHandler()])
logger = logging.getLogger(__name__)
def is_prime(n: int) -> bool:
    """Return True if ``n`` is prime, using trial division.

    Values below 2 (0, 1 and all negatives) are not prime.

    The divisor bound uses :func:`math.isqrt`, which is exact for integers of
    any size. ``int(n**0.5)`` loses precision for very large ``n`` and raises
    OverflowError once ``n`` no longer fits in a float.
    """
    # Without this guard the divisor range is empty for 0 and 1 (so they
    # would be reported as prime), and negatives would break the square root.
    if n < 2:
        return False
    for k in range(2, math.isqrt(n) + 1):
        if n % k == 0:
            return False
    return True
def count_primes(n: int) -> int:
    """Return how many primes ``p`` satisfy ``2 <= p < n``.

    The upper bound is exclusive: ``count_primes(7)`` counts 2, 3 and 5 only.
    """
    total = 0
    for candidate in range(2, n):
        if not is_prime(candidate):
            continue
        total += 1
    return total
@njit
def is_prime_numba(n: int) -> bool:
    """Use numba to check whether ``n`` is prime by trial division.

    Values below 2 (0, 1 and all negatives) are reported as not prime.
    """
    # Without this guard the divisor range is empty for 0 and 1 (so they
    # would be reported as prime), and a negative n would be fed to n**0.5.
    if n < 2:
        return False
    result = True
    # Float square root is kept (rather than math.isqrt) so the body stays
    # within what numba's nopython mode compiles.
    for k in range(2, int(n**0.5) + 1):
        if n % k == 0:
            result = False
            break
    return result
@njit
def count_primes_numba(n: int) -> int:
    """Count primes ``p`` with ``2 <= p < n`` in a numba-compiled loop."""
    total = 0
    for candidate in range(2, n):
        if not is_prime_numba(candidate):
            continue
        total += 1
    return total
# Initialize Taichi on its CPU backend; this runs once at module import,
# before any @ti.kernel below is called.
# `arch=ti.gpu` averages .04s slower on my M1 MBP w/32GB RAM.
# Could probably tweak that by shifting more RAM to GPU,
# but not on first coffee.
ti.init(arch=ti.cpu)
@ti.func
def is_prime_ti(n: int) -> bool:
    """Use taichi to check whether ``n`` is prime by trial division.

    Called from the ``count_primes_ti`` kernel, which only passes n >= 2.
    NOTE(review): for n of 0 or 1 the divisor loop is empty, so this would
    report them as prime; add a guard if it is ever called with such values.
    """
    result = True
    # Trial division: test every k in [2, floor(sqrt(n))]. This is not the
    # Sieve of Eratosthenes, which instead crosses off multiples in a table.
    # The single `result` variable with one trailing return is presumably
    # kept to stay within Taichi's restrictions on returns inside ti.func;
    # confirm before restructuring.
    for k in range(2, int(n**0.5) + 1):
        if n % k == 0:
            result = False
            break
    return result
@ti.kernel
def count_primes_ti(n: int) -> int:
    """Use taichi to count the primes ``p`` with ``2 <= p < n``."""
    count = 0
    # NOTE(review): per the Taichi docs, the outermost for loop of a kernel
    # is parallelized and `count += 1` inside it is compiled as an atomic
    # add, so the total is race-free; confirm for the Taichi version in use.
    for k in range(2, n):
        if is_prime_ti(k):
            count += 1
    return count
def use_approach(approach: SieveFunction, n: int, iterations: int) -> ApproachResult:
    """Run ``approach(n)`` once, then time ``iterations`` further calls.

    The untimed first call supplies the reported prime count. It also keeps
    one-time, first-call costs (such as JIT compilation) out of the timing.

    Returns:
        A row of strings for the report: the approach's ``__name__``, the
        count with thousands separators, and the total and mean time in
        seconds, each to three decimal places.
    """
    name = approach.__name__
    prime_count = approach(n)

    timer = timeit.Timer(lambda: approach(n))
    elapsed = timer.timeit(number=iterations)
    per_call = elapsed / iterations

    logger.debug(
        "use approach '%s': result=%d; total time=%f; mean time=%f",
        name,
        prime_count,
        elapsed,
        per_call,
    )
    row = (name, f"{prime_count:,}", f"{elapsed:.3f}", f"{per_call:.3f}")
    return row
def result_report(n: int, iterations: int, results: list[ApproachResult]) -> Table:
    """Build a Rich table summarizing each approach's result and timings.

    This only formats rows that were already measured (see ``use_approach``);
    it does not time anything itself.

    Args:
        n: Upper bound that was passed to every approach; shown in the title.
        iterations: Number of timed calls per approach; shown in the title.
        results: One pre-formatted ``(name, result, total time, mean time)``
            row per approach.

    Returns:
        A ``rich.table.Table`` ready to be printed.
    """
    table = Table(
        title=f"Prime Sieve Timing (n={n:,}, {iterations:,} iterations)",
        expand=True,
    )
    table.add_column("Approach")
    # Numeric columns are right-aligned so digits and decimals line up.
    for heading in ("Result", "Total Time", "Mean Time"):
        table.add_column(heading, justify="right")
    for row in results:
        table.add_row(*row)
    return table
def main():
    """Time every prime-counting approach on N and print a comparison table."""
    logger.debug("starting prime sieve check: n=%d; iterations=%d", N, ITERATIONS)
    approaches = (count_primes, count_primes_numba, count_primes_ti)
    results = [use_approach(fn, N, ITERATIONS) for fn in approaches]
    table = result_report(N, ITERATIONS, results)
    Console().print(table)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment