Last active
March 21, 2025 00:31
-
-
Save msaroufim/079a8d08ffebd0f91a1c2247eb0ce9e0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Minimal example that:
- Includes only the headers it needs: <ATen/core/Tensor.h> (for at::Tensor),
  <ATen/Functions.h> (for at::empty), and <c10/cuda/CUDAGuard.h> (for the device guard).
- Avoids <torch/extension.h> and <torch/types.h>.
- Uses <torch/csrc/utils/pybind.h> so pybind11 can cast torch.Tensor <-> at::Tensor.
- Demonstrates a custom CUDA kernel that computes x + y + 1.
- Uses no_implicit_headers=True to reduce compile overhead.
""" | |
import os | |
import shutil | |
from datetime import datetime | |
import torch | |
import torch.utils.cpp_extension | |
# If you have a custom CUDA include path, set it here.
# expanduser("~") is used instead of os.environ["HOME"] so that an unset
# HOME variable does not abort the script with a KeyError.
cuda_include_dir = os.path.join(
    os.path.expanduser("~"), ".conda/envs/pt/targets/x86_64-linux/include"
)

# Start from an empty build directory on every run: remove whatever a
# previous run left behind, then recreate the directory.
BUILD_DIR = os.path.join(os.getcwd(), "minimal_tensor_build")
if os.path.exists(BUILD_DIR):
    print(f"Removing build directory: {BUILD_DIR}")
    shutil.rmtree(BUILD_DIR)
os.makedirs(BUILD_DIR, exist_ok=True)
print(f"Created build directory: {BUILD_DIR}")
# --------------------------------------------------------------------------
# C++ source: minimal includes
#
# Defines tensor_add_cpp(x, y), which validates its inputs, allocates the
# output tensor and calls launch_add_kernel (only forward-declared here),
# and exposes it through a pybind11 module named TORCH_EXTENSION_NAME.
# --------------------------------------------------------------------------
cpp_source = r"""
#include <ATen/core/Tensor.h>         // at::Tensor
#include <ATen/Functions.h>           // at::empty(...) and other creation ops
#include <c10/cuda/CUDAGuard.h>       // at::cuda::CUDAGuard
#include <pybind11/pybind11.h>        // pybind11
#include <torch/csrc/utils/pybind.h>  // Allows torch.Tensor <-> at::Tensor casting

// Forward-declare our CUDA kernel launcher
void launch_add_kernel(const float* x_data,
                       const float* y_data,
                       float* out_data,
                       int64_t num_elements);

// Simple function: x + y + 1
at::Tensor tensor_add_cpp(const at::Tensor& x, const at::Tensor& y) {
    // Basic checks
    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
    TORCH_CHECK(y.is_cuda(), "y must be a CUDA tensor");
    // The kernel dereferences both raw pointers under a guard for x's device
    // only, so the two inputs have to live on the same device.
    TORCH_CHECK(x.device() == y.device(), "x and y must be on the same CUDA device");
    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, "x must be float32");
    TORCH_CHECK(y.scalar_type() == at::ScalarType::Float, "y must be float32");
    TORCH_CHECK(x.sizes() == y.sizes(), "x and y must have the same shape");
    TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), "x and y must be contiguous tensors");

    // Create output on the same device & dtype as x
    auto out = at::empty(x.sizes(), x.options());

    // Ensure we're on the correct device and call our kernel
    at::cuda::CUDAGuard device_guard(x.device());
    launch_add_kernel(x.data_ptr<float>(),
                      y.data_ptr<float>(),
                      out.data_ptr<float>(),
                      x.numel());
    return out;
}

// pybind11 module definition
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("tensor_add_cpp",
          &tensor_add_cpp,
          "Add x + y + 1 (float32 CUDA).");
}
"""
# --------------------------------------------------------------------------
# CUDA source
#
# Defines add_kernel (out[i] = x[i] + y[i] + 1) and its host-side launcher
# launch_add_kernel, which uses 256 threads per block.
# --------------------------------------------------------------------------
cuda_source = r"""
#include <cstdint>
#include <cuda_runtime.h>

__global__ void add_kernel(const float* x, const float* y, float* out, int64_t size) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise computed
    // in 32 bits and would wrap for very large tensors.
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < size) {
        out[idx] = x[idx] + y[idx] + 1.0f;
    }
}

void launch_add_kernel(const float* x_data,
                       const float* y_data,
                       float* out_data,
                       int64_t num_elements) {
    // An empty tensor would yield a grid of 0 blocks, which is an invalid
    // launch configuration; there is nothing to compute anyway.
    if (num_elements <= 0) {
        return;
    }
    const int threads = 256;
    // Keep the block count in 64 bits until the launch so the ceil-division
    // itself cannot be truncated.
    const int64_t blocks = (num_elements + threads - 1) / threads;
    add_kernel<<<static_cast<unsigned int>(blocks), threads>>>(x_data, y_data, out_data, num_elements);
}
"""
def main():
    """Build the inline extension, report the build time, and test the kernel.

    Compiles ``cpp_source`` and ``cuda_source`` with ``load_inline`` into
    ``BUILD_DIR`` and prints how long that took. If CUDA is available, it then
    runs ``tensor_add_cpp`` on two random float32 tensors and compares the
    result against ``x + y + 1``, printing whether the check passed.
    """
    # Build the extension with minimal includes
    start_time = datetime.now()
    module = torch.utils.cpp_extension.load_inline(
        name="minimal_tensor_extension",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        verbose=True,
        # True = do not auto-add the heavy torch headers; the sources above
        # already include everything they need. (This was False, which
        # contradicted the stated intent of avoiding the auto-includes.)
        no_implicit_headers=True,
        extra_include_paths=[cuda_include_dir],
        build_directory=BUILD_DIR,
        extra_cuda_cflags=["-arch=sm_80"],
    )
    total_time = datetime.now() - start_time
    print(f"\nExtension compiled in {total_time}!\n")

    if not torch.cuda.is_available():
        print("CUDA is not available. Exiting.")
        return

    # Create test tensors
    x = torch.randn(10, device="cuda", dtype=torch.float32)
    y = torch.randn(10, device="cuda", dtype=torch.float32)

    # Call our custom function
    result = module.tensor_add_cpp(x, y)

    # Check correctness (our kernel does x + y + 1)
    expected = x + y + 1.0
    max_diff = (result - expected).abs().max()
    print(f"Max difference: {max_diff.item()}")
    if torch.allclose(result, expected):
        print("Test PASSED! ✓")
    else:
        print("Test FAILED!")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
(pt) ➜ examples git:(msaroufim/noheader) ✗ python tensor_base_example.py
Removing build directory: /home/marksaroufim/pytorch/examples/minimal_tensor_build
Created build directory: /home/marksaroufim/pytorch/examples/minimal_tensor_build
Detected CUDA files, patching ldflags
Emitting ninja build file /home/marksaroufim/pytorch/examples/minimal_tensor_build/build.ninja...
Building extension module minimal_tensor_extension...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] /home/marksaroufim/.conda/envs/pt/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=minimal_tensor_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -I/home/marksaroufim/.conda/envs/pt/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/TH -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/THC -isystem /home/marksaroufim/.conda/envs/pt/include -isystem /home/marksaroufim/.conda/envs/pt/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -arch=sm_80 -std=c++17 -c /home/marksaroufim/pytorch/examples/minimal_tensor_build/cuda.cu -o cuda.cuda.o
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
[2/3] /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=minimal_tensor_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -I/home/marksaroufim/.conda/envs/pt/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/TH -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/THC -isystem /home/marksaroufim/.conda/envs/pt/include -isystem /home/marksaroufim/.conda/envs/pt/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/marksaroufim/pytorch/examples/minimal_tensor_build/main.cpp -o main.o
[3/3] /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-c++ main.o cuda.cuda.o -shared -L/home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/home/marksaroufim/.conda/envs/pt/lib -lcudart -o minimal_tensor_extension.so
Loading extension module minimal_tensor_extension...
Extension compiled in 0:00:14.889961!
Max difference: 0.0
Test PASSED! ✓
(pt) ➜ examples git:(msaroufim/noheader) ✗ python tensor_base_example.py
Removing build directory: /home/marksaroufim/pytorch/examples/minimal_tensor_build
Created build directory: /home/marksaroufim/pytorch/examples/minimal_tensor_build
Detected CUDA files, patching ldflags
Emitting ninja build file /home/marksaroufim/pytorch/examples/minimal_tensor_build/build.ninja...
Building extension module minimal_tensor_extension...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=minimal_tensor_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -I/home/marksaroufim/.conda/envs/pt/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/TH -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/THC -isystem /home/marksaroufim/.conda/envs/pt/include -isystem /home/marksaroufim/.conda/envs/pt/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/marksaroufim/pytorch/examples/minimal_tensor_build/main.cpp -o main.o
[2/3] /home/marksaroufim/.conda/envs/pt/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=minimal_tensor_extension -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE="_gcc" -DPYBIND11_STDLIB="_libstdcpp" -DPYBIND11_BUILD_ABI="_cxxabi1011" -I/home/marksaroufim/.conda/envs/pt/targets/x86_64-linux/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/TH -isystem /home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/include/THC -isystem /home/marksaroufim/.conda/envs/pt/include -isystem /home/marksaroufim/.conda/envs/pt/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_90,code=sm_90 --compiler-options '-fPIC' -arch=sm_80 -std=c++17 -c /home/marksaroufim/pytorch/examples/minimal_tensor_build/cuda.cu -o cuda.cuda.o
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
[3/3] /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-c++ main.o cuda.cuda.o -shared -L/home/marksaroufim/.conda/envs/pt/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda -ltorch -ltorch_python -L/home/marksaroufim/.conda/envs/pt/lib -lcudart -o minimal_tensor_extension.so
Loading extension module minimal_tensor_extension...
Extension compiled in 0:01:04.020993!
Max difference: 0.0
Test PASSED! ✓
(pt) ➜ examples git:(msaroufim/noheader) ✗