Skip to content

Instantly share code, notes, and snippets.

View maxidl's full-sized avatar

Max Idahl maxidl

  • Hanover, Germany
View GitHub Profile
@maxidl
maxidl / pytorch-nccl-test.sh
Last active August 5, 2025 07:55
SLURM PyTorch NCCL Multi-Node Test Script: A SLURM batch script that tests PyTorch's NCCL functionality across multiple GPU nodes. The script sets up a distributed PyTorch environment using torchrun and runs a comprehensive test that verifies NCCL initialization, inter-process communication barriers, and proper cleanup. Includes diagnostic outpu…
#!/bin/bash
# SLURM batch header for the multi-node PyTorch NCCL test job.
#SBATCH --job-name=pytorch-nccl-test
# Cluster-specific settings: intentionally left empty in this template;
# fill in partition/account/QoS for your cluster before submitting.
#SBATCH --partition=
#SBATCH --account=
#SBATCH --qos=
# Resources: 2 nodes, one task per node, 32 CPUs per task,
# and 4 H100 GPUs requested via --gres (a per-node request in SLURM).
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=32
#SBATCH --gres=gpu:H100:4
# Wall-clock limit: 5 minutes.
#SBATCH --time 0:05:00
@maxidl
maxidl / fp8_cast_bf16.py
Created July 14, 2025 12:48
A version of https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py that uses CPU instead of GPU memory to load and save dequantized weights. Only the dequantization step itself runs on the GPU, giving a much smaller GPU memory footprint than the original script. Runtime is longer, but this enables conversion of fp…
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file
from kernel import weight_dequant
@maxidl
maxidl / openai_pyarrow_schemas.py
Created June 16, 2025 16:10
PyArrow Schemas for OpenAI Completion and ChatCompletion
"""
OpenAI Python SDK Compatible Schemas.
Compatible with v1.86.0. Subject to change.
See https://github.com/openai/openai-python for the latest version.
"""
import pyarrow as pa
COMPLETION_SCHEMA = pa.schema([
pa.field('id', pa.string()),
import argparse
import copy
import torch
import datasets as hfds
import transformers
from tqdm.auto import tqdm
import wandb
from functools import partial
import types
import torch
from typing import List, Optional, Tuple, Union, Dict
import transformers
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging as hf_logging
# Module-level logger obtained through transformers' logging utilities
# (transformers.utils.logging, imported as hf_logging).
logger = hf_logging.get_logger(__name__)
@maxidl
maxidl / c.py
Created February 10, 2022 21:49
def get_simple_gradient_expl(model, images, targets, absolute=False):
    """Compute vanilla-gradient explanations for each sample's target class.

    The explanation for sample ``i`` is the gradient of the model score at
    class ``targets[i]`` with respect to ``images[i]``.

    Args:
        model: Callable mapping ``images`` to per-class scores. The output is
            indexed with ``gather(1, ...)``, so it is presumably shaped
            (batch, num_classes) -- TODO confirm for other model heads.
        images: Input batch. If it does not already require grad, its
            ``requires_grad`` flag is switched on in place, which modifies
            the caller's tensor.
        targets: 1-D tensor of class indices, one per sample in ``images``.
        absolute: If True, return the element-wise absolute value of the
            gradient instead of the signed gradient.

    Returns:
        A tensor with the same shape as ``images``. It is computed with
        ``create_graph=True``, so it stays attached to the autograd graph and
        can itself be differentiated (second-order derivatives).
    """
    # Only touch the flag when needed. Assigning ``requires_grad`` on a
    # non-leaf tensor raises, even if the tensor already requires grad.
    if not images.requires_grad:
        images.requires_grad_(True)
    outputs = model(images)
    # Pick each sample's target-class score -> shape (batch, 1).
    target_scores = outputs.gather(1, targets.unsqueeze(1))
    # Gradients from multiple roots are summed, so a single scalar root
    # (the sum of the scores) yields the same gradient as one root per row.
    grad = torch.autograd.grad(target_scores.sum(), images, create_graph=True)[0]  # create_graph=True for second order derivative
    return grad.abs() if absolute else grad
@maxidl
maxidl / b.py
Created February 10, 2022 11:25
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from pathlib import Path
from tqdm.auto import tqdm
# Report whether CUDA is visible, then use the GPU if present, else the CPU.
print(torch.cuda.is_available())
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# To force CPU execution regardless of CUDA availability, uncomment:
# dev = torch.device("cpu")
@maxidl
maxidl / a.py
Last active February 9, 2022 13:12
import sys
import os # noqa
sys.path.insert(0, ".") # noqa
import torch
from utils.styled_plot import plt
from utils.dataset import load_test_image, preprocess_image, normalize_image, convert_idx_to_label
from classifiers.cnn_classifier import ImageNetClassifier
from solutions.explainers import plot_attributions, aggregate_attribution, normalize_attribution
@maxidl
maxidl / transformers_integratedgradients_batched.py
Last active February 15, 2021 13:48
Generate attributions for Transformers models using Captum, processing inputs in batches instead of one instance at a time for higher total throughput.
# Prepare the model for attribution: move it to `device`, put it in eval
# mode, and clear any previously accumulated parameter gradients.
model.to(device)
model.eval()
model.zero_grad()
def forward_func(inputs, attention_mask=None):
    """Call ``model`` on ``inputs`` and return only the ``.logits`` of its output.

    This is the forward callable handed to ``LayerIntegratedGradients``.
    ``model`` is resolved from the enclosing scope, not passed in.
    """
    model_output = model(inputs, attention_mask=attention_mask)
    return model_output.logits
# Attribute with respect to the BERT embedding layer, driving the model via forward_func.
lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)
# Nine accumulators, each bound to its own fresh empty list.
(
    all_input_ids,
    all_ref_input_ids,
    all_attributions,
    all_pred_probs,
    all_pred_class,
    all_true_class,
    all_attr_class,
    all_attr_score,
    all_convergence_scores,
) = [[] for _ in range(9)]
@maxidl
maxidl / imagenet_idx_to_label.json
Last active June 21, 2020 19:04
ImageNet dataset mapping from class index to human-readable class label, in proper JSON format.
{
"0": "tench, Tinca tinca",
"1": "goldfish, Carassius auratus",
"2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
"3": "tiger shark, Galeocerdo cuvieri",
"4": "hammerhead, hammerhead shark",
"5": "electric ray, crampfish, numbfish, torpedo",
"6": "stingray",
"7": "cock",
"8": "hen",