Skip to content

Instantly share code, notes, and snippets.

@msaroufim
Last active March 21, 2025 00:31
Show Gist options
  • Save msaroufim/079a8d08ffebd0f91a1c2247eb0ce9e0 to your computer and use it in GitHub Desktop.
"""
Minimal example that:
- Only includes <ATen/core/Tensor.h> (for at::Tensor)
and <ATen/Functions.h> (for at::empty).
- Avoids <torch/extension.h> or <torch/types.h>.
- Uses <torch/csrc/utils/pybind.h> so PyBind can cast torch.Tensor <-> at::Tensor.
- Demonstrates a custom CUDA kernel that adds x + y + 1.
- Uses no_implicit_headers=True to reduce compile overhead.
"""
import os
import shutil
from datetime import datetime
import torch
import torch.utils.cpp_extension
# If you have a custom CUDA include path, set it here:
cuda_include_dir = os.path.join(os.environ["HOME"], ".conda/envs/pt/targets/x86_64-linux/include")
BUILD_DIR = os.path.join(os.getcwd(), "minimal_tensor_build")
if os.path.exists(BUILD_DIR):
print(f"Removing build directory: {BUILD_DIR}")
shutil.rmtree(BUILD_DIR)
os.makedirs(BUILD_DIR, exist_ok=True)
print(f"Created build directory: {BUILD_DIR}")
# --------------------------------------------------------------------------
# C++ source: minimal includes
# --------------------------------------------------------------------------
# Host-side binding code, compiled verbatim by load_inline() below. It only
# validates inputs (CUDA, float32, same shape, contiguous), allocates the
# output with at::empty, sets the device guard, and forwards raw float
# pointers to the CUDA launcher declared here and defined in cuda_source.
# The string's contents are runtime data and are intentionally untouched.
cpp_source = r"""
#include <ATen/core/Tensor.h> // at::Tensor
#include <ATen/Functions.h> // at::empty(...) and other creation ops
#include <c10/cuda/CUDAGuard.h> // at::cuda::CUDAGuard
#include <pybind11/pybind11.h> // pybind11
#include <torch/csrc/utils/pybind.h> // Allows torch.Tensor <-> at::Tensor casting
// Forward-declare our CUDA kernel launcher
void launch_add_kernel(const float* x_data,
const float* y_data,
float* out_data,
int64_t num_elements);
// Simple function: x + y + 1
at::Tensor tensor_add_cpp(const at::Tensor& x, const at::Tensor& y) {
// Basic checks
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(y.is_cuda(), "y must be a CUDA tensor");
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, "x must be float32");
TORCH_CHECK(y.scalar_type() == at::ScalarType::Float, "y must be float32");
TORCH_CHECK(x.sizes() == y.sizes(), "x and y must have the same shape");
TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), "x and y must be contiguous tensors");
// Create output on the same device & dtype as x
auto out = at::empty(x.sizes(), x.options());
// Ensure we're on the correct device and call our kernel
at::cuda::CUDAGuard device_guard(x.device());
launch_add_kernel(x.data_ptr<float>(),
y.data_ptr<float>(),
out.data_ptr<float>(),
x.numel());
return out;
}
// pybind11 module definition
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("tensor_add_cpp",
&tensor_add_cpp,
"Add x + y + 1 (float32 CUDA).");
}
"""
# --------------------------------------------------------------------------
# CUDA source
# --------------------------------------------------------------------------
# Device-side code: a one-thread-per-element kernel computing out = x + y + 1
# (guarded by `idx < size`), plus a host launcher using 256-thread blocks.
# NOTE(review): the launcher does no cudaGetLastError()/synchronize after the
# launch and uses the default stream rather than PyTorch's current stream —
# launch failures surface only on a later CUDA call. Left as-is since the
# string's contents are runtime data compiled by load_inline().
cuda_source = r"""
#include <cuda_runtime.h>
__global__ void add_kernel(const float* x, const float* y, float* out, int64_t size) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < size) {
out[idx] = x[idx] + y[idx] + 1.0f;
}
}
void launch_add_kernel(const float* x_data,
const float* y_data,
float* out_data,
int64_t num_elements) {
const int threads = 256;
const int blocks = (num_elements + threads - 1) / threads;
add_kernel<<<blocks, threads>>>(x_data, y_data, out_data, num_elements);
}
"""
def main():
    """Build the inline extension, then verify tensor_add_cpp(x, y) == x + y + 1."""
    # Check for a GPU *before* paying for the build: the extension is compiled
    # against CUDA and exercised on CUDA tensors, so without a GPU the build
    # time (tens of seconds, per the logs) would be wasted.
    if not torch.cuda.is_available():
        print("CUDA is not available. Exiting.")
        return

    # Build the extension with minimal includes.
    start_time = datetime.now()
    module = torch.utils.cpp_extension.load_inline(
        name="minimal_tensor_extension",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        verbose=True,
        # True matches both the inline intent ("avoid heavy auto-includes")
        # and the module docstring; the C++ source above already includes
        # everything it needs explicitly. (Was False, contradicting both.)
        no_implicit_headers=True,
        extra_include_paths=[cuda_include_dir],
        build_directory=BUILD_DIR,
        # NOTE(review): hard-codes sm_80 (Ampere); adjust for other GPU archs.
        extra_cuda_cflags=["-arch=sm_80"],
    )
    total_time = datetime.now() - start_time
    print(f"\nExtension compiled in {total_time}!\n")

    # Create small test tensors on the GPU.
    x = torch.randn(10, device="cuda", dtype=torch.float32)
    y = torch.randn(10, device="cuda", dtype=torch.float32)

    # Call our custom function; the kernel computes x + y + 1.
    result = module.tensor_add_cpp(x, y)
    expected = x + y + 1.0

    max_diff = (result - expected).abs().max()
    print(f"Max difference: {max_diff.item()}")
    if torch.allclose(result, expected):
        print("Test PASSED! ✓")
    else:
        print("Test FAILED!")


if __name__ == "__main__":
    main()
@msaroufim
Copy link
Author

(pt) ➜ examples git:(msaroufim/noheader) ✗ python tensor_base_example.py
Removing build directory: /home/marksaroufim/pytorch/examples/minimal_tensor_build
Created build directory: /home/marksaroufim/pytorch/examples/minimal_tensor_build
Detected CUDA files, patching ldflags
Emitting ninja build file /home/marksaroufim/pytorch/examples/minimal_tensor_build/build.ninja...
Building extension module minimal_tensor_extension...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] /home/marksaroufim/.conda/envs/pt/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=minimal_tensor_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="gcc" -DPYBIND11_STDLIB="libstdcpp" -DPYBIND11_BUILD_ABI="cxxabi1011" -I/home/marksaroufim/.conda/envs/pt/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/TH -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/THC -isystem /home/marksaroufim/.conda/envs/pt/include -isystem /home/marksaroufim/.conda/envs/pt/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS -D__CUDA_NO_HALF_CONVERSIONS_ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -arch=sm_80 -std=c++17 -c /home/marksaroufim/pytorch/examples/minimal_tensor_build/cuda.cu -o cuda.cuda.o
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
[2/3] /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=minimal_tensor_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -I/home/marksaroufim/.conda/envs/pt/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/TH -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/THC -isystem /home/marksaroufim/.conda/envs/pt/include -isystem /home/marksaroufim/.conda/envs/pt/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/marksaroufim/pytorch/examples/minimal_tensor_build/main.cpp -o main.o
[3/3] /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-c++ main.o cuda.cuda.o -shared -L/home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/home/marksaroufim/.conda/envs/pt/lib -lcudart -o minimal_tensor_extension.so
Loading extension module minimal_tensor_extension...

Extension compiled in 0:00:14.889961!

Max difference: 0.0
Test PASSED! ✓
(pt) ➜ examples git:(msaroufim/noheader) ✗ python tensor_base_example.py
Removing build directory: /home/marksaroufim/pytorch/examples/minimal_tensor_build
Created build directory: /home/marksaroufim/pytorch/examples/minimal_tensor_build
Detected CUDA files, patching ldflags
Emitting ninja build file /home/marksaroufim/pytorch/examples/minimal_tensor_build/build.ninja...
Building extension module minimal_tensor_extension...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=minimal_tensor_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="gcc" -DPYBIND11_STDLIB="libstdcpp" -DPYBIND11_BUILD_ABI="cxxabi1011" -I/home/marksaroufim/.conda/envs/pt/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/TH -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/THC -isystem /home/marksaroufim/.conda/envs/pt/include -isystem /home/marksaroufim/.conda/envs/pt/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/marksaroufim/pytorch/examples/minimal_tensor_build/main.cpp -o main.o
[2/3] /home/marksaroufim/.conda/envs/pt/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=minimal_tensor_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="gcc" -DPYBIND11_STDLIB="libstdcpp" -DPYBIND11_BUILD_ABI="cxxabi1011" -I/home/marksaroufim/.conda/envs/pt/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/TH -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/THC -isystem /home/marksaroufim/.conda/envs/pt/include -isystem /home/marksaroufim/.conda/envs/pt/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS -D__CUDA_NO_HALF_CONVERSIONS
-D__CUDA_NO_BFLOAT16_CONVERSIONS
-D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -arch=sm_80 -std=c++17 -c /home/marksaroufim/pytorch/examples/minimal_tensor_build/cuda.cu -o cuda.cuda.o
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
[3/3] /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-c++ main.o cuda.cuda.o -shared -L/home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/home/marksaroufim/.conda/envs/pt/lib -lcudart -o minimal_tensor_extension.so
Loading extension module minimal_tensor_extension...

Extension compiled in 0:01:04.020993!

Max difference: 0.0
Test PASSED! ✓
(pt) ➜ examples git:(msaroufim/noheader) ✗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment