python/sglang/srt/weight_loader/gguf_loader.py
from __future__ import annotations
import io
import mmap
import pathlib
import struct
import time
from enum import IntEnum
from typing import Any, BinaryIO, Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
# Optional tqdm for progress bar
try:
    import tqdm
except ImportError:
    tqdm = None
# ────────────────────────────
# basic binary helpers
# ────────────────────────────
_MAGIC = b"GGUF"
_HDR = struct.Struct("<4sIQQ") # magic, version, n_tensors, n_meta
_UINT8 = struct.Struct("<B")
_UINT16 = struct.Struct("<H")
_UINT32 = struct.Struct("<I")
_UINT64 = struct.Struct("<Q")
_FP16 = struct.Struct("<e") # little-endian IEEE-754 half
# ────────────────────────────
# GGML / GGUF enums
# ────────────────────────────
class GGMLType(IntEnum):
    F32 = 0
    F16 = 1
    Q4_0 = 2
    Q4_1 = 3
    Q8_0 = 6
    Q8_1 = 7
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15
    I8 = 16
    I16 = 17
    I32 = 18
    I64 = 19
    F64 = 20
    BF16 = 21
    # aliases for older ggml versions where the integer ids shift by +1
    I8_ALT = 17
    I16_ALT = 18
    I32_ALT = 19
    I64_ALT = 20
class GGUFMetaValueType(IntEnum):
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12
_META_VALUE_STRUCTS = {
    GGUFMetaValueType.UINT8: _UINT8,
    GGUFMetaValueType.INT8: struct.Struct("<b"),
    GGUFMetaValueType.UINT16: _UINT16,
    GGUFMetaValueType.INT16: struct.Struct("<h"),
    GGUFMetaValueType.UINT32: _UINT32,
    GGUFMetaValueType.INT32: struct.Struct("<i"),
    GGUFMetaValueType.FLOAT32: struct.Struct("<f"),
    GGUFMetaValueType.UINT64: _UINT64,
    GGUFMetaValueType.INT64: struct.Struct("<q"),
    GGUFMetaValueType.FLOAT64: struct.Struct("<d"),
}
# ────────────────────────────
# De-quant helpers
# ────────────────────────────
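# Each helper below fills a preallocated float32 destination array block by
# block, reading the per-block layout assumed by the matching _TYPE_INFO entry
# (typically one or two fp16 scale values followed by the packed quants).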
def _dequant_q8_0(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        q = np.frombuffer(buf[pos:pos+BS], dtype=np.int8).astype(np.float32); pos += BS
        dst[i*BS:(i+1)*BS] = q * scale
def _dequant_q4_0(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        qbytes = np.frombuffer(buf[pos:pos+16], dtype=np.uint8); pos += 16
        q_lo = (qbytes & 0x0F).astype(np.float32) - 8.0
        q_hi = (qbytes >> 4).astype(np.float32) - 8.0
        blk = np.empty(BS, dtype=np.float32)
        blk[0::2] = q_lo * scale
        blk[1::2] = q_hi * scale
        dst[i*BS:(i+1)*BS] = blk
def _dequant_q4_1(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        min_ = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        qbytes = np.frombuffer(buf[pos:pos+16], dtype=np.uint8); pos += 16
        lo = (qbytes & 0x0F).astype(np.float32)
        hi = (qbytes >> 4).astype(np.float32)
        blk = np.empty(BS, dtype=np.float32)
        blk[0::2] = lo * scale + min_
        blk[1::2] = hi * scale + min_
        dst[i*BS:(i+1)*BS] = blk
def _dequant_q8_1(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        zero = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        q = np.frombuffer(buf[pos:pos+BS], dtype=np.int8).astype(np.float32); pos += BS
        dst[i*BS:(i+1)*BS] = q * scale + zero
def _dequant_q4_k(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 256
    for i in range(n_blocks):
        d = np.float64(_FP16.unpack_from(buf, pos)[0]); pos += 2
        m = np.float64(_FP16.unpack_from(buf, pos)[0]); pos += 2
        d = np.nan_to_num(d, nan=0.0)
        m = np.nan_to_num(m, nan=0.0)
        scales_and_mins_raw = np.frombuffer(buf[pos:pos+16], dtype=np.uint8); pos += 16
        scales = (scales_and_mins_raw & 0x0F).astype(np.float64) * d
        mins = (scales_and_mins_raw >> 4).astype(np.float64) * m
        q = np.frombuffer(buf[pos:pos+128], dtype=np.uint8); pos += 128
        blk_dst = dst[i*BS:(i+1)*BS]
        for j in range(16):
            start, end = j*16, (j+1)*16
            q_lo = q[j*8:(j+1)*8] & 0x0F
            q_hi = q[j*8:(j+1)*8] >> 4
            sub_q = np.concatenate([q_lo, q_hi]).astype(np.float32)
            dequantized_block = sub_q * scales[j] + mins[j]
            np.nan_to_num(dequantized_block, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
            blk_dst[start:end] = dequantized_block
def _dequant_q6_k(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 256
    for i in range(n_blocks):
        scale_super = np.float64(_FP16.unpack_from(buf, pos)[0]); pos += 2
        scale_super = np.nan_to_num(scale_super, nan=0.0)
        scales = np.frombuffer(buf[pos:pos+16], dtype=np.int8).astype(np.float64); pos += 16
        ql = np.frombuffer(buf[pos:pos+128], dtype=np.uint8); pos += 128
        qh = np.frombuffer(buf[pos:pos+64], dtype=np.uint8); pos += 64
        q = np.empty(BS, dtype=np.int8)
        q[0:128] = (ql & 0x0F)
        q[128:256] = (ql >> 4)
        q[0:64] |= ((qh >> 0) & 3) << 4
        q[64:128] |= ((qh >> 2) & 3) << 4
        q[128:192] |= ((qh >> 4) & 3) << 4
        q[192:256] |= ((qh >> 6) & 3) << 4
        q_final = (q - 32).astype(np.float32)
        blk_dst = dst[i*BS:(i+1)*BS]
        for j in range(16):
            start, end = j*16, (j+1)*16
            dequantized_block = scale_super * scales[j] * q_final[start:end]
            np.nan_to_num(dequantized_block, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
            blk_dst[start:end] = dequantized_block
# Type-info map
_TYPE_INFO: Dict[GGMLType, Dict[str, Any]] = {
    # primitive
    GGMLType.F32: dict(torch_dtype=torch.float32),
    GGMLType.F16: dict(torch_dtype=torch.float16),
    GGMLType.BF16: dict(torch_dtype=torch.bfloat16),
    GGMLType.F64: dict(torch_dtype=torch.float64),
    GGMLType.I8: dict(torch_dtype=torch.int8),
    GGMLType.I16: dict(torch_dtype=torch.int16),
    GGMLType.I32: dict(torch_dtype=torch.int32),
    GGMLType.I64: dict(torch_dtype=torch.int64),
    GGMLType.I8_ALT: dict(torch_dtype=torch.int8),
    GGMLType.I16_ALT: dict(torch_dtype=torch.int16),
    GGMLType.I32_ALT: dict(torch_dtype=torch.int32),
    GGMLType.I64_ALT: dict(torch_dtype=torch.int64),
    # quantised (supported)
    GGMLType.Q8_0: dict(block=32, size=34, dequant=_dequant_q8_0),
    GGMLType.Q4_0: dict(block=32, size=18, dequant=_dequant_q4_0),
    GGMLType.Q4_1: dict(block=32, size=20, dequant=_dequant_q4_1),
    GGMLType.Q8_1: dict(block=32, size=36, dequant=_dequant_q8_1),
    GGMLType.Q4_K: dict(block=256, size=148, dequant=_dequant_q4_k),
    GGMLType.Q6_K: dict(block=256, size=210, dequant=_dequant_q6_k),
    # quantised (not yet ported)
    GGMLType.Q2_K: dict(supported=False),
    GGMLType.Q3_K: dict(supported=False),
    GGMLType.Q5_K: dict(supported=False),
    GGMLType.Q8_K: dict(supported=False),
}
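# Note: the "size" fields above are the per-block byte counts consumed by the
# matching de-quant helpers, e.g. Q8_0 = 2 (fp16 scale) + 32 (int8 quants) = 34
# bytes per 32-value block, and Q4_0 = 2 + 16 bytes of packed nibbles = 18.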
# ────────────────────────────
# meta helpers
# ────────────────────────────
def _read_string(buf: BinaryIO, file_size: int) -> str:
    length = _UINT64.unpack(buf.read(8))[0]
    if buf.tell() + length > file_size:
        raise ValueError("GGUF file truncated while reading string")
    return buf.read(length).decode("utf-8")
def _read_meta_value(buf: BinaryIO, vt: GGUFMetaValueType, file_size: int) -> Any:
    s = _META_VALUE_STRUCTS.get(vt)
    if s:
        return s.unpack(buf.read(s.size))[0]
    if vt == GGUFMetaValueType.BOOL:
        return struct.unpack("<?", buf.read(1))[0]
    if vt == GGUFMetaValueType.STRING:
        return _read_string(buf, file_size)
    if vt == GGUFMetaValueType.ARRAY:
        item_type = GGUFMetaValueType(_UINT32.unpack(buf.read(4))[0])
        count = _UINT64.unpack(buf.read(8))[0]
        return [_read_meta_value(buf, item_type, file_size) for _ in range(count)]
    raise ValueError(f"unhandled meta type {vt}")
# ────────────────────────────
# tensor-descriptor
# ────────────────────────────
class _TensorInfo:
    __slots__ = ("name", "shape", "dtype", "offset", "n_bytes")
    def __init__(self, name: str, shape: List[int], dtype: GGMLType, offset: int):
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.offset = offset
        self.n_bytes = 0
# ────────────────────────────
# public loader
# ────────────────────────────
class GGUFLoader:
    def __init__(self, path: str | pathlib.Path, target_dtype: torch.dtype = torch.float16):
        self._path = pathlib.Path(path)
        self._target_dtype = target_dtype
    def load(self) -> Tuple[Dict[str, Any], Dict[str, torch.Tensor]]:
        if not self._path.is_file():
            raise FileNotFoundError(self._path)
        meta, tensors, tensor_data_buffers = self._read_structure()
        weights: Dict[str, torch.Tensor] = {}
        tensor_iterator = zip(tensors, tensor_data_buffers)
        if tqdm:
            tensor_iterator = tqdm.tqdm(
                tensor_iterator,
                total=len(tensors),
                desc="Loading tensors",
                unit="tensors"
            )
        for t, raw_data in tensor_iterator:
            info = _TYPE_INFO.get(t.dtype)
            if info is None:
                raise ValueError(f"{t.name}: unknown GGMLType {t.dtype}")
            if not info.get("supported", True):
                raise NotImplementedError(f"{t.dtype.name} not yet supported")
            numel = int(np.prod(t.shape)) if t.shape else 1
            if "torch_dtype" in info:
                buf = bytearray(raw_data)
                tensor = torch.frombuffer(buf, dtype=info["torch_dtype"]).reshape(*t.shape).clone()
            else:
                bs, fn = info["block"], info["dequant"]
                out = np.empty(numel, dtype=np.float32)
                fn(memoryview(raw_data), out, numel // bs)
                tensor = torch.from_numpy(out).reshape(*t.shape)
            weights[t.name] = tensor.to(self._target_dtype)
        return meta, weights
    def _read_structure(self) -> Tuple[Dict[str, Any], List[_TensorInfo], List[bytes]]:
        tensor_data_buffers = []
        with self._path.open("rb") as fh, mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            file_size = len(mm)
            bf = io.BytesIO(mm)
            magic, version, n_tensors, n_meta = _HDR.unpack(bf.read(_HDR.size))
            if magic != _MAGIC:
                raise ValueError("not a GGUF file")
            if version > 3:
                print(f"warning: GGUF v{version} – parser tested up to v3")
            meta: Dict[str, Any] = {}
            for _ in range(n_meta):
                key = _read_string(bf, file_size)
                vt = GGUFMetaValueType(_UINT32.unpack(bf.read(4))[0])
                meta[key] = _read_meta_value(bf, vt, file_size)
            alignment = int(meta.get("general.alignment", 32))
            tensors: List[_TensorInfo] = []
            for _ in range(n_tensors):
                name = _read_string(bf, file_size)
                ndims = _UINT32.unpack(bf.read(4))[0]
                dims = [_UINT64.unpack(bf.read(8))[0] for _ in range(ndims)]
                shape = list(reversed(dims))
                dtype = GGMLType(_UINT32.unpack(bf.read(4))[0])
                offset = _UINT64.unpack(bf.read(8))[0]
                tensors.append(_TensorInfo(name, shape, dtype, offset))
            data_start_offset = bf.tell()
            padding = -data_start_offset % alignment
            data_start_offset += padding
            for t in tensors:
                t.offset += data_start_offset
                numel = int(np.prod(t.shape)) if t.shape else 1
                info = _TYPE_INFO.get(t.dtype)
                if info is None:
                    raise ValueError(f"{t.name}: unknown GGMLType {t.dtype}")
                if "torch_dtype" in info:
                    elt = torch.tensor([], dtype=info["torch_dtype"])
                    t.n_bytes = numel * elt.element_size()
                elif "block" in info:
                    bs, blk_bytes = info["block"], info["size"]
                    if numel % bs:
                        raise ValueError(f"{t.name}: {numel} not mult of {bs}")
                    t.n_bytes = (numel // bs) * blk_bytes
                else:
                    if not info.get("supported", True):
                        t.n_bytes = 0
                    else:
                        raise ValueError(f"Tensor {t.name} has quant type {t.dtype.name} with no size info.")
                if t.offset % alignment:
                    raise ValueError(f"{t.name} offset {t.offset} not aligned ({alignment})")
                if t.n_bytes > 0:
                    tensor_data_buffers.append(mm[t.offset : t.offset + t.n_bytes])
                elif not info.get("supported", True):
                    tensor_data_buffers.append(b'')
        return meta, tensors, tensor_data_buffers
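# Example usage (sketch; "model.gguf" is a placeholder path):
#   meta, weights = GGUFLoader("model.gguf", target_dtype=torch.float16).load()
#   print(meta.get("general.architecture"), len(weights), "tensors")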
# ───────────────────────────────────────────────────
# ─── INFERENCE EXAMPLE AND PYTORCH MODEL BELOW ───
# ───────────────────────────────────────────────────
class ModelArgs:
    def __init__(self, meta: Dict[str, Any], target_dtype: torch.dtype):
        self.dtype = target_dtype
        self.dim = meta["llama.embedding_length"]
        self.n_layers = meta["llama.block_count"]
        self.n_heads = meta["llama.attention.head_count"]
        self.n_kv_heads = meta["llama.attention.head_count_kv"]
        self.vocab_size = meta["llama.vocab_size"]
        self.hidden_dim = meta["llama.feed_forward_length"]
        self.rope_dim = meta.get("llama.rope.dimension_count", self.dim // self.n_heads)
        self.norm_eps = meta["llama.attention.layer_norm_rms_epsilon"]
        self.max_seq_len = meta.get("llama.context_length", 2048)
        self.rope_freq_base = meta.get("llama.rope.freq_base", 10000.0)
        self.bos_token_id = meta.get("tokenizer.bos_token_id", 1)
        self.eos_token_id = meta.get("tokenizer.eos_token_id", 2)
        self.pad_token_id = meta.get("tokenizer.pad_token_id", -1)
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
class RoPE:
    def __init__(self, dim: int, max_seq_len: int, base: float, dtype: torch.dtype):
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.dtype = dtype
        self._precompute_freqs_cis()
    def _precompute_freqs_cis(self):
        freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(self.max_seq_len, device=freqs.device)
        freqs_cis = torch.polar(torch.ones_like(t[:, None] * freqs[None, :]), t[:, None] * freqs[None, :])
        self.freqs_cis = freqs_cis
    def __call__(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int):
        seq_len = xq.shape[1]
        freqs = self.freqs_cis[start_pos : start_pos + seq_len].to(xq.device)
        # Perform the rotation in complex32 (backed by float16) to avoid dtype
        # mismatches; the precision impact is negligible for inference.
        freqs_c32 = freqs.to(torch.complex32)
        freqs_c32 = freqs_c32.view(1, seq_len, 1, -1)
        # View xq and xk as complex tensors; the reshape pairs up the last
        # dimension into (real, imaginary) components.
        xq_c32 = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
        xk_c32 = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
        # Apply RoPE and flatten back to the original shape.
        xq_rotated = torch.view_as_real(xq_c32 * freqs_c32).flatten(3)
        xk_rotated = torch.view_as_real(xk_c32 * freqs_c32).flatten(3)
        # Cast back to the original dtype to prevent type-promotion issues downstream.
        return xq_rotated.to(xq.dtype), xk_rotated.to(xk.dtype)
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.head_dim = args.dim // self.n_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        # Register the KV cache as buffers so it is moved to the correct device
        # along with the model.
        self.register_buffer('cache_k', torch.zeros((1, args.max_seq_len, self.n_kv_heads, self.head_dim), dtype=args.dtype))
        self.register_buffer('cache_v', torch.zeros((1, args.max_seq_len, self.n_kv_heads, self.head_dim), dtype=args.dtype))
    def forward(self, x: torch.Tensor, start_pos: int, rope: RoPE):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xq, xk = rope(xq, xk, start_pos)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        keys = keys[:, :, :, None, :].expand(bsz, -1, self.n_kv_heads, self.n_rep, self.head_dim).reshape(bsz, -1, self.n_heads, self.head_dim)
        values = values[:, :, :, None, :].expand(bsz, -1, self.n_kv_heads, self.n_rep, self.head_dim).reshape(bsz, -1, self.n_heads, self.head_dim)
        # Ensure query, key, and value tensors share the same dtype.
        tgt_dtype = keys.dtype  # usually torch.float16
        if xq.dtype != tgt_dtype:
            xq = xq.to(tgt_dtype)
        if values.dtype != tgt_dtype:
            values = values.to(tgt_dtype)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        is_causal = seqlen > 1
        scores = torch.nn.functional.scaled_dot_product_attention(xq, keys, values, is_causal=is_causal)
        output = scores.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        # Cast to wo's weight dtype before the output projection.
        return self.wo(output.to(self.wo.weight.dtype))
class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
    def forward(self, x: torch.Tensor, start_pos: int, rope: RoPE):
        h = x + self.attention(self.attention_norm(x), start_pos, rope)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        self.rope = RoPE(args.rope_dim, args.max_seq_len, args.rope_freq_base, args.dtype)
    def forward(self, tokens: torch.Tensor, start_pos: int):
        h = self.tok_embeddings(tokens)
        for layer in self.layers:
            h = layer(h, start_pos, self.rope)
        h = self.norm(h)
        return self.output(h).float()
    def reset_kv_cache(self):
        for layer in self.layers:
            layer.attention.cache_k.zero_()
            layer.attention.cache_v.zero_()
class SimpleTokenizer:
    def __init__(self, meta: Dict[str, Any]):
        # Store every vocabulary entry as bytes so decoding can concatenate them directly.
        self.tokens = [token.encode('utf-8') if isinstance(token, str) else token for token in meta['tokenizer.ggml.tokens']]
        self.scores = meta.get('tokenizer.ggml.scores')
        self.tok_map = {tok: i for i, tok in enumerate(self.tokens)}
        self.bos_token_id = int(meta.get('tokenizer.bos_token_id', 1))
        self.eos_token_id = int(meta.get('tokenizer.eos_token_id', 2))
        self.unk_token_id = int(meta.get('tokenizer.unk_token_id', 0))
    def encode(self, text: str, bos: bool = True) -> List[int]:
        token_ids = []
        if bos:
            token_ids.append(self.bos_token_id)
        for char_byte in text.encode('utf-8'):
            char = bytes([char_byte])
            token_ids.append(self.tok_map.get(char, self.unk_token_id))
        return token_ids
    def decode(self, token_ids: List[int]) -> str:
        filtered_ids = [tid for tid in token_ids if tid not in (self.bos_token_id, self.eos_token_id)]
        # self.tokens holds bytes objects, so joining works without further conversion.
        byte_tokens = b"".join(self.tokens[tid] for tid in filtered_ids if tid < len(self.tokens))
        return byte_tokens.decode('utf-8', errors='replace')
def map_gguf_weights_to_pytorch(
    gguf_weights: Dict[str, torch.Tensor], args: ModelArgs
) -> Dict[str, torch.Tensor]:
    pytorch_weights = {}
    for name, tensor in gguf_weights.items():
        if name == "token_embd.weight":
            pytorch_weights["tok_embeddings.weight"] = tensor
        elif name == "output_norm.weight":
            pytorch_weights["norm.weight"] = tensor
        elif name == "output.weight" or name == "lm_head.weight":
            pytorch_weights["output.weight"] = tensor
        elif name.startswith("blk."):
            parts = name.split('.')
            layer_num = parts[1]
            layer_type_full = ".".join(parts[2:])
            if layer_type_full == "attn_norm.weight":
                new_name = f"layers.{layer_num}.attention_norm.weight"
            elif layer_type_full == "ffn_norm.weight":
                new_name = f"layers.{layer_num}.ffn_norm.weight"
            elif layer_type_full == "attn_q.weight":
                new_name = f"layers.{layer_num}.attention.wq.weight"
            elif layer_type_full == "attn_k.weight":
                new_name = f"layers.{layer_num}.attention.wk.weight"
            elif layer_type_full == "attn_v.weight":
                new_name = f"layers.{layer_num}.attention.wv.weight"
            elif layer_type_full == "attn_output.weight":
                new_name = f"layers.{layer_num}.attention.wo.weight"
            elif layer_type_full == "ffn_gate.weight":
                new_name = f"layers.{layer_num}.feed_forward.w1.weight"
            elif layer_type_full == "ffn_down.weight":
                new_name = f"layers.{layer_num}.feed_forward.w2.weight"
            elif layer_type_full == "ffn_up.weight":
                new_name = f"layers.{layer_num}.feed_forward.w3.weight"
            else:
                print(f"Warning: Unhandled GGUF block tensor: {name}")
                continue
            pytorch_weights[new_name] = tensor
        else:
            if name != "rope_freqs.weight":
                print(f"Warning: Unhandled GGUF tensor (will be ignored): {name}")
    if "output.weight" not in pytorch_weights and "tok_embeddings.weight" in pytorch_weights:
        print("Info: Tying output.weight to tok_embeddings.weight.")
        pytorch_weights["output.weight"] = pytorch_weights["tok_embeddings.weight"]
    expected_keys = Transformer(args).state_dict().keys()
    missing_keys = [k for k in expected_keys if k not in pytorch_weights]
    if missing_keys:
        print(f"Warning: The following keys were expected but not found: {missing_keys}")
    return pytorch_weights
@torch.no_grad()
def generate(
    model: Transformer,
    tokenizer: SimpleTokenizer,
    prompt: str,
    max_tokens: int = 100
):
    model.eval()
    model.reset_kv_cache()
    token_ids = tokenizer.encode(prompt, bos=True)
    if not token_ids:
        print("\nPrompt is empty.")
        return
    # Determine the model's device and move the input tensor to it.
    device = next(model.parameters()).device
    prompt_tokens = torch.tensor([token_ids], dtype=torch.long, device=device)
    start_time = time.time()
    logits = model(prompt_tokens, 0)
    prefill_time = time.time() - start_time
    generated_ids = []
    # Subsequent tensors will be on the correct device because they come from the model's output.
    current_token = torch.argmax(logits[:, -1, :], dim=-1)
    start_gen_time = time.time()
    for i in range(max_tokens):
        if current_token.item() == tokenizer.eos_token_id:
            break
        generated_ids.append(current_token.item())
        print(tokenizer.decode([current_token.item()]), end="", flush=True)
        logits = model(current_token.view(1, 1), len(token_ids) + i)
        current_token = torch.argmax(logits[:, -1, :], dim=-1)
    gen_time = time.time() - start_gen_time
    print("\n\n--- Stats ---")
    prefill_tok_per_sec = len(token_ids) / prefill_time if prefill_time > 0 else float('inf')
    gen_tok_per_sec = len(generated_ids) / gen_time if gen_time > 0 else float('inf')
    print(f"Prefill time: {prefill_time:.2f}s ({prefill_tok_per_sec:.2f} t/s)")
    print(f"Generation time: {gen_time:.2f}s ({gen_tok_per_sec:.2f} t/s)")
    print(f"Total tokens generated: {len(generated_ids)}")
if __name__ == "__main__":
    import sys
    if not tqdm:
        print("Note: 'tqdm' not found. To see a loading progress bar, run: pip install tqdm", file=sys.stderr)
    # --- Configuration ---
    # <<< PLEASE REPLACE THIS WITH THE PATH TO YOUR GGUF FILE >>>
    GGUF_PATH = r'D:\ia\levi_chat\llm\Llama-3.2-1B-Instruct-Q4_K_M.gguf'
    TARGET_DTYPE = torch.float16
    PROMPT = "The capital of France is"
    MAX_TOKENS_TO_GENERATE = 100
    try:
        # --- 1. Load GGUF file ---
        print(f"Loading GGUF file: {GGUF_PATH}")
        start_load_time = time.time()
        loader = GGUFLoader(GGUF_PATH, target_dtype=TARGET_DTYPE)
        meta, gguf_weights = loader.load()
        print(f"GGUF loaded in {time.time() - start_load_time:.2f}s")
        if 'general.architecture' not in meta or meta['general.architecture'] != 'llama':
            print(f"Error: This script is designed for 'llama' architecture, "
                  f"but the model is '{meta.get('general.architecture', 'unknown')}'.", file=sys.stderr)
            sys.exit(1)
        # --- 2. Set up Model and Tokenizer ---
        args = ModelArgs(meta, TARGET_DTYPE)
        model = Transformer(args)
        tokenizer = SimpleTokenizer(meta)
        # Move the model to the GPU if available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        # --- 3. Map weights and load into model ---
        print("\nMapping and loading weights into the PyTorch model...")
        pytorch_weights = map_gguf_weights_to_pytorch(gguf_weights, args)
        model.load_state_dict(pytorch_weights, strict=False)
        print("Weights loaded successfully.")
        # --- Print Model Info ---
        print("\n--- Model Info ---")
        print(f"- Architecture: {meta.get('general.architecture')}")
        print(f"- Layers: {args.n_layers}, Heads: {args.n_heads}, Dim: {args.dim}")
        print(f"- Target DType: {TARGET_DTYPE}")
        print(f"- Vocab Size: {args.vocab_size}")
        print(f"- Inference Device: {device}")
        # --- 4. Run Inference ---
        print("\n--- Inference ---")
        print(f"Prompt: {PROMPT}")
        print("Response:", end=" ", flush=True)
        generate(model, tokenizer, PROMPT, max_tokens=MAX_TOKENS_TO_GENERATE)
        print("\n\nSuccessfully loaded and ran GGUF file.")
    except (FileNotFoundError, NotImplementedError, ValueError, KeyError) as e:
        print(f"\nError: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        import traceback
        print(f"\nAn unexpected error occurred: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)