This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import os | |
os.environ["MTL_CAPTURE_ENABLED"]="1" | |
a = torch.ones(2, (1 << 31) + 5, dtype=torch.int8, device='mps') | |
index_0 = torch.tensor([0, -1, 0, 1], device=a.device) | |
index_1 = torch.tensor([-2, -1, 0, 1], device=a.device) | |
values = torch.tensor([12, 13, 10, 11], dtype=a.dtype, device=a.device) | |
with torch.mps.profiler.metal_capture("index_put"): | |
a.index_put_((index_0, index_1), values, accumulate=True) | |
b = a[1, -2].cpu() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import sys | |
import subprocess | |
import urllib.request | |
import json | |
def get_latest_version(package_name: str) -> str: | |
"""Get latest version from PyPI""" | |
api_url = f"https://pypi.org/pypi/{package_name}/json" |
This file has been truncated, but you can view the full file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
With cudnn-9.10.2.21 | |
``` | |
$ CUDNN_LOGINFO_DBG=3 RUN_SLOW=1 python3 -m pytest -v tests/models/vit/test_modeling_vit.py::ViTModelTest::test_batching_equivalence | |
========================================================================================== test session starts =========================================================================================== | |
platform linux -- Python 3.10.12, pytest-8.4.1, pluggy-1.6.0 -- /home/ubuntu/py3.10-nightly/bin/python3 | |
cachedir: .pytest_cache | |
rootdir: /home/ubuntu/transformers | |
configfile: pyproject.toml | |
plugins: xdist-3.8.0, asyncio-1.1.0, rerunfailures-15.1, order-1.3.0, timeout-2.4.0, rich-0.2.0 | |
asyncio: mode=strict, asyncio_default_fixture_loop_scope=function, asyncio_default_test_loop_scope=function |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.utils.cpp_extension as _ce | |
import tempfile | |
import os | |
import subprocess | |
src = """#include <c10/util/BFloat16.h> | |
#include <iostream> | |
int main() { | |
std::cout << c10::BFloat16(3.14) << std::endl; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ctypes | |
import torch | |
import time | |
def nvrtc_compile(source: str) -> str: | |
from ctypes import CDLL, c_void_p, c_char_p, c_size_t, byref, create_string_buffer | |
libnvrtc = CDLL('libnvrtc.so') | |
def get_error_string() -> str: | |
err_p = c_char_p() | |
libnvrtc.nvrtcGetErrorString(result, byref(err_str)) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Example showing how to use the no_header mode with a TensorBase CUDA extension | |
This example creates a CUDA extension that directly includes ATen/core/TensorBase.h | |
instead of torch/extension.h, resulting in faster compilation with no_header=True | |
""" | |
from datetime import datetime | |
import torch | |
import torch.utils.cpp_extension |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dis | |
import timeit | |
def list_to_dict_1(l): | |
rc = {} | |
for idx, v in enumerate(l): | |
rc[v] = idx | |
return rc | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Fail with Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" on M1/M2 (using MacOS 15.3.1) | |
// Works on M4 (and may be M3) | |
let shader_source = """ | |
template <typename T> | |
float bessel_j0_forward(T x) { | |
constexpr float PP[] = { | |
+7.96936729297347051624e-04, | |
+8.28352392107440799803e-02, | |
+1.23953371646414299388e+00, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# How to reuse shared memory | |
# Right now MPS inductor produces following code | |
# #include <c10/metal/random.h> | |
# #include <c10/metal/special_math.h> | |
# #include <c10/metal/utils.h> | |
# #include <c10/metal/reduction_utils.h> | |
# kernel void generated_kernel( | |
# device float* out_ptr0, | |
# device float* out_ptr1, | |
# constant float* in_ptr0, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
let shader_source = """ | |
struct add_functor { | |
template <typename T> | |
inline T operator()(const T a, const T b) { | |
return static_cast<T>(a + b); | |
} | |
}; | |
namespace { | |
struct sub_functor { |
NewerOlder