Created
May 1, 2025 18:07
-
-
Save brianwisti/6d399fdcc3ef6dccb5f958ef950049ed to your computer and use it in GitHub Desktop.
Prime sieve comparison in plain Python, Numba, and Taichi
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Count the prime numbers below n (the range [2, n)), comparing the same
trial-division code in pure Python, Numba, and Taichi. This is the sort
of thing that happens when I have coffee and time on my hands while
looking at the `Taichi docs`_.
.. _Taichi docs: https://docs.taichi-lang.org/docs/accelerate_python | |
""" | |
import logging | |
import timeit | |
from collections.abc import Callable | |
import taichi as ti | |
from numba import njit | |
from rich.console import Console | |
from rich.logging import RichHandler | |
from rich.table import Table | |
# Uncomment for prettier screenshots. | |
# import warnings | |
# warnings.simplefilter("ignore") | |
# Benchmark parameters: count the primes below N, timing ITERATIONS calls
# of each approach.
N = 1_000_000
ITERATIONS = 100

# Every approach is invoked as ``approach(n)`` and returns the prime count,
# so it takes a single int argument (previously mistyped as two ints).
SieveFunction = Callable[[int], int]
# One report row: (approach name, prime count, total seconds, mean seconds),
# each already formatted as a display string.
ApproachResult = tuple[str, str, str, str]
# Route log records through Rich's ``RichHandler``. Everything this module
# logs is emitted at DEBUG, which the INFO level below filters out; switch
# to logging.DEBUG to see those messages.
logging.basicConfig(
    handlers=[RichHandler()],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def is_prime(n: int) -> bool:
    """Return True if *n* is a prime number.

    Uses trial division by every integer k with 2 <= k <= sqrt(n).
    Values below 2 (0, 1 and negatives) are not prime.
    """
    if n < 2:
        # Without this guard 0 and 1 fall through the empty loop as "prime",
        # and a negative n makes ``n**0.5`` complex, so ``int()`` raises.
        return False
    # Any factor larger than sqrt(n) pairs with one smaller than it, so
    # checking up to the square root is sufficient.
    for k in range(2, int(n**0.5) + 1):
        if n % k == 0:
            return False
    return True
def count_primes(n: int) -> int:
    """Return how many primes p satisfy 2 <= p < n, in plain Python.

    ``n`` itself is never counted, and any ``n <= 2`` yields 0.
    """
    total = 0
    for candidate in range(2, n):
        if not is_prime(candidate):
            continue
        total += 1
    return total
@njit
def is_prime_numba(n: int) -> bool:
    """Return True if *n* is prime; Numba-compiled twin of ``is_prime``.

    Uses trial division by every k with 2 <= k <= sqrt(n). Values below 2
    are not prime.
    """
    if n < 2:
        # Guard: for 0 and 1 the loop below is empty and would report
        # "prime"; for negative n the square root is not a real number.
        return False
    for k in range(2, int(n**0.5) + 1):
        if n % k == 0:
            return False
    return True
@njit
def count_primes_numba(n: int) -> int:
    """Return how many primes p satisfy 2 <= p < n, compiled with Numba.

    Same loop as ``count_primes``, calling ``is_prime_numba`` per candidate.
    """
    total = 0
    for candidate in range(2, n):
        if not is_prime_numba(candidate):
            continue
        total += 1
    return total
# Initialise Taichi on its CPU backend. This runs at import time, before the
# Taichi functions and kernels below are defined.
# Author's measurement: `arch=ti.gpu` averaged about 0.04 s slower on an M1
# MacBook Pro with 32 GB RAM. Shifting more RAM to the GPU might change that,
# but it has not been tried.
ti.init(arch=ti.cpu)
@ti.func
def is_prime_ti(n: int) -> bool:
    """Return True if *n* is prime; Taichi function called from kernels.

    NOTE(review): for n = 0 or 1 the loop range is presumably empty, so this
    returns True; count_primes_ti only passes n >= 2. The ``int``/``bool``
    annotations are read by Taichi to type the function, so leave them as-is.
    """
    result = True
    # Trial division by every k up to sqrt(n) (not the Sieve of Eratosthenes).
    for k in range(2, int(n**0.5) + 1):
        if n % k == 0:
            result = False
            break
    return result
@ti.kernel
def count_primes_ti(n: int) -> int:
    """Return how many primes p satisfy 2 <= p < n, computed by a Taichi kernel.

    NOTE(review): Taichi types ``n`` and the return value from the ``int``
    annotations (its default integer type, presumably 32-bit) — confirm before
    using very large n.
    """
    count = 0
    # NOTE(review): per the Taichi docs, the outermost for-loop of a kernel runs
    # in parallel and `count += 1` on this kernel-scope local becomes an atomic
    # add, so the total matches the serial versions — confirm for your version.
    for k in range(2, n):
        if is_prime_ti(k):
            count += 1
    return count
def use_approach(approach: SieveFunction, n: int, iterations: int) -> ApproachResult:
    """Time one prime-counting approach and format its results for the report.

    The approach is first called once, untimed, to obtain its result; this
    presumably also acts as a warm-up so one-off first-call costs (e.g. JIT
    compilation) stay out of the measurement. It is then timed over
    ``iterations`` calls with ``timeit``.

    Args:
        approach: Function called as ``approach(n)`` that returns a prime count.
        n: Argument passed to ``approach``.
        iterations: Number of timed calls; must be at least 1.

    Returns:
        ``(name, result, total_seconds, mean_seconds)`` as display strings:
        the result comma-grouped, the times to three decimal places.

    Raises:
        ValueError: If ``iterations`` is less than 1.
    """
    if iterations < 1:
        # 0 would divide by zero below; a negative count makes timeit run
        # nothing and report a meaningless near-zero time.
        raise ValueError(f"iterations must be at least 1, got {iterations}")
    approach_name = approach.__name__
    result = approach(n)
    total_time = timeit.timeit(lambda: approach(n), number=iterations)
    mean_time = total_time / iterations
    logger.debug(
        "use approach '%s': result=%d; total time=%f; mean time=%f",
        approach_name,
        result,
        total_time,
        mean_time,
    )
    return approach_name, f"{result:,}", f"{total_time:.3f}", f"{mean_time:.3f}"
def result_report(n: int, iterations: int, results: list[ApproachResult]) -> Table:
    """Build a Rich table summarising each approach's result and timings.

    Args:
        n: Argument that was passed to every approach (shown in the title).
        iterations: Number of timed calls per approach (shown in the title).
        results: One preformatted row per approach, in display order.

    Returns:
        A ``Table`` with Approach, Result, Total Time and Mean Time columns;
        the three numeric columns are right-justified.
    """
    table = Table(
        title=f"Prime Sieve Timing (n={n:,}, {iterations:,} iterations)",
        expand=True,
    )
    table.add_column("Approach")
    for heading in ("Result", "Total Time", "Mean Time"):
        table.add_column(heading, justify="right")
    for row in results:
        table.add_row(*row)
    return table
def main():
    """Benchmark each prime-counting implementation and print a summary table."""
    logger.debug("starting prime sieve check: n=%d; iterations=%d", N, ITERATIONS)
    approaches = (count_primes, count_primes_numba, count_primes_ti)
    results = [use_approach(fn, N, ITERATIONS) for fn in approaches]
    table = result_report(N, ITERATIONS, results)
    Console().print(table)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment