A version of https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py that uses CPU instead of GPU memory to load and save the dequantized weights. Only the dequantization step itself is executed on the GPU, giving a much smaller GPU memory footprint than the original script. Runtime is longer, but this enables conversion of fp…
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file
from kernel import weight_dequant


def main(fp8_path, bf16_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]

    # Cache for loaded safetensor files
    loaded_files = {}
    fp8_weight_names = []

    # Helper function to get a tensor from the correct file (loaded into CPU memory)
    def get_tensor(tensor_name):
        file_name = weight_map[tensor_name]
        if file_name not in loaded_files:
            file_path = os.path.join(fp8_path, file_name)
            loaded_files[file_name] = load_file(file_path, device="cpu")
        return loaded_files[file_name][tensor_name]

    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files.sort()
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cpu")
        loaded_files[file_name] = current_state_dict

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if weight_name.endswith("_scale_inv"):
                continue
            elif weight.element_size() == 1:  # FP8 weight
                scale_inv_name = f"{weight_name}_scale_inv"
                try:
                    # Get scale_inv from the correct file
                    scale_inv = get_tensor(scale_inv_name)
                    fp8_weight_names.append(weight_name)
                    # Move only this weight and its scale to the GPU, dequantize there,
                    # and bring the BF16 result straight back to CPU memory
                    weight_gpu = weight.float().cuda()
                    scale_inv_gpu = scale_inv.cuda()
                    new_state_dict[weight_name] = weight_dequant(weight_gpu, scale_inv_gpu).cpu()
                except KeyError:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                    new_state_dict[weight_name] = weight.cpu()
            else:
                new_state_dict[weight_name] = weight.cpu()

        new_safetensor_file = os.path.join(bf16_path, file_name)
        save_file(new_state_dict, new_safetensor_file)

        # Memory management: keep only the 2 most recently used files
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()

    # Update model index: drop the scale_inv entries for converted weights
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    for weight_name in fp8_weight_names:
        scale_inv_name = f"{weight_name}_scale_inv"
        if scale_inv_name in weight_map:
            weight_map.pop(scale_inv_name)
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
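
For reference, a minimal usage sketch, assuming the script is saved as fp8_cast_bf16.py next to DeepSeek-V3's inference kernel.py (which provides weight_dequant); the module name and checkpoint paths below are placeholders:

# Hypothetical module name and placeholder paths; only the per-tensor
# dequantization touches the GPU, all other data stays in CPU RAM.
from fp8_cast_bf16 import main

main("/path/to/DeepSeek-V3-fp8", "/path/to/DeepSeek-V3-bf16")

# Equivalent CLI:
#   python fp8_cast_bf16.py --input-fp8-hf-path /path/to/DeepSeek-V3-fp8 \
#       --output-bf16-hf-path /path/to/DeepSeek-V3-bf16

The output directory will contain BF16 safetensors shards with the same file names as the input, plus an updated model.safetensors.index.json with the _scale_inv entries removed.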