python/sglang/srt/weight_loader/gguf_loader.py
from __future__ import annotations

import io
import mmap
import pathlib
import struct
import time
from enum import IntEnum
from typing import Any, BinaryIO, Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn

# Optional tqdm for progress bar
try:
    import tqdm
except ImportError:
    tqdm = None

# ────────────────────────────
# basic binary helpers
# ────────────────────────────
_MAGIC = b"GGUF"
_HDR = struct.Struct("<4sIQQ")  # magic, version, n_tensors, n_meta
_UINT8 = struct.Struct("<B")
_UINT16 = struct.Struct("<H")
_UINT32 = struct.Struct("<I")
_UINT64 = struct.Struct("<Q")
_FP16 = struct.Struct("<e")  # little-endian IEEE-754 half
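# For reference, the fixed-size GGUF header that _HDR unpacks is laid out as:
#   bytes 0-3    magic   b"GGUF"
#   bytes 4-7    uint32  format version
#   bytes 8-15   uint64  tensor count
#   bytes 16-23  uint64  metadata key/value count
# (24 bytes total, little-endian), matching struct.Struct("<4sIQQ").size.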
# ────────────────────────────
# GGML / GGUF enums
# ────────────────────────────
class GGMLType(IntEnum):
    F32 = 0
    F16 = 1
    Q4_0 = 2
    Q4_1 = 3
    Q8_0 = 6
    Q8_1 = 7
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15
    I8 = 16
    I16 = 17
    I32 = 18
    I64 = 19
    F64 = 20
    BF16 = 21
    # Aliases for ggml builds where these integer ids are shifted by +1.
    # Because the values collide with I16/I32/I64/F64 above, Python's IntEnum
    # treats these names as aliases of those members.
    I8_ALT = 17
    I16_ALT = 18
    I32_ALT = 19
    I64_ALT = 20

class GGUFMetaValueType(IntEnum):
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12

_META_VALUE_STRUCTS = {
    GGUFMetaValueType.UINT8: _UINT8,
    GGUFMetaValueType.INT8: struct.Struct("<b"),
    GGUFMetaValueType.UINT16: _UINT16,
    GGUFMetaValueType.INT16: struct.Struct("<h"),
    GGUFMetaValueType.UINT32: _UINT32,
    GGUFMetaValueType.INT32: struct.Struct("<i"),
    GGUFMetaValueType.FLOAT32: struct.Struct("<f"),
    GGUFMetaValueType.UINT64: _UINT64,
    GGUFMetaValueType.INT64: struct.Struct("<q"),
    GGUFMetaValueType.FLOAT64: struct.Struct("<d"),
}
# ────────────────────────────
# De-quant helpers
# ────────────────────────────
def _dequant_q8_0(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        q = np.frombuffer(buf[pos:pos+BS], dtype=np.int8).astype(np.float32); pos += BS
        dst[i*BS:(i+1)*BS] = q * scale

def _dequant_q4_0(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        qbytes = np.frombuffer(buf[pos:pos+16], dtype=np.uint8); pos += 16
        q_lo = (qbytes & 0x0F).astype(np.float32) - 8.0
        q_hi = (qbytes >> 4).astype(np.float32) - 8.0
        blk = np.empty(BS, dtype=np.float32)
        blk[0::2] = q_lo * scale
        blk[1::2] = q_hi * scale
        dst[i*BS:(i+1)*BS] = blk

def _dequant_q4_1(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        min_ = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        qbytes = np.frombuffer(buf[pos:pos+16], dtype=np.uint8); pos += 16
        lo = (qbytes & 0x0F).astype(np.float32)
        hi = (qbytes >> 4).astype(np.float32)
        blk = np.empty(BS, dtype=np.float32)
        blk[0::2] = lo * scale + min_
        blk[1::2] = hi * scale + min_
        dst[i*BS:(i+1)*BS] = blk

def _dequant_q8_1(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        zero = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        q = np.frombuffer(buf[pos:pos+BS], dtype=np.int8).astype(np.float32); pos += BS
        dst[i*BS:(i+1)*BS] = q * scale + zero

def _dequant_q4_k(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 256
    for i in range(n_blocks):
        d = np.float64(_FP16.unpack_from(buf, pos)[0]); pos += 2
        m = np.float64(_FP16.unpack_from(buf, pos)[0]); pos += 2
        d = np.nan_to_num(d, nan=0.0)
        m = np.nan_to_num(m, nan=0.0)
        scales_and_mins_raw = np.frombuffer(buf[pos:pos+16], dtype=np.uint8); pos += 16
        scales = (scales_and_mins_raw & 0x0F).astype(np.float64) * d
        mins = (scales_and_mins_raw >> 4).astype(np.float64) * m
        q = np.frombuffer(buf[pos:pos+128], dtype=np.uint8); pos += 128
        blk_dst = dst[i*BS:(i+1)*BS]
        for j in range(16):
            start, end = j*16, (j+1)*16
            q_lo = q[j*8:(j+1)*8] & 0x0F
            q_hi = q[j*8:(j+1)*8] >> 4
            sub_q = np.concatenate([q_lo, q_hi]).astype(np.float32)
            dequantized_block = sub_q * scales[j] + mins[j]
            np.nan_to_num(dequantized_block, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
            blk_dst[start:end] = dequantized_block

def _dequant_q6_k(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    pos, BS = 0, 256
    for i in range(n_blocks):
        scale_super = np.float64(_FP16.unpack_from(buf, pos)[0]); pos += 2
        scale_super = np.nan_to_num(scale_super, nan=0.0)
        scales = np.frombuffer(buf[pos:pos+16], dtype=np.int8).astype(np.float64); pos += 16
        ql = np.frombuffer(buf[pos:pos+128], dtype=np.uint8); pos += 128
        qh = np.frombuffer(buf[pos:pos+64], dtype=np.uint8); pos += 64
        q = np.empty(BS, dtype=np.int8)
        q[0:128] = (ql & 0x0F)
        q[128:256] = (ql >> 4)
        q[0:64] |= ((qh >> 0) & 3) << 4
        q[64:128] |= ((qh >> 2) & 3) << 4
        q[128:192] |= ((qh >> 4) & 3) << 4
        q[192:256] |= ((qh >> 6) & 3) << 4
        q_final = (q - 32).astype(np.float32)
        blk_dst = dst[i*BS:(i+1)*BS]
        for j in range(16):
            start, end = j*16, (j+1)*16
            dequantized_block = scale_super * scales[j] * q_final[start:end]
            np.nan_to_num(dequantized_block, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
            blk_dst[start:end] = dequantized_block
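# Illustrative sanity check (an addition for clarity, not called anywhere in
# this file): quantise a single Q8_0 block by hand and confirm that
# _dequant_q8_0 reconstructs it. It assumes the same per-block layout the
# helper above reads: one little-endian fp16 scale followed by 32 int8 quants.
def _selftest_q8_0() -> bool:
    rng = np.random.default_rng(0)
    values = rng.uniform(-1.0, 1.0, 32).astype(np.float32)
    # Use a half-precision scale so the packed/unpacked value round-trips exactly.
    scale = np.float16(np.abs(values).max() / 127.0)
    quants = np.round(values / np.float32(scale)).astype(np.int8)
    block = _FP16.pack(float(scale)) + quants.tobytes()  # 2 + 32 = 34 bytes
    out = np.empty(32, dtype=np.float32)
    _dequant_q8_0(memoryview(block), out, 1)
    return bool(np.allclose(out, quants.astype(np.float32) * np.float32(scale)))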
# Type-info map
_TYPE_INFO: Dict[GGMLType, Dict[str, Any]] = {
    # primitive
    GGMLType.F32: dict(torch_dtype=torch.float32),
    GGMLType.F16: dict(torch_dtype=torch.float16),
    GGMLType.BF16: dict(torch_dtype=torch.bfloat16),
    GGMLType.F64: dict(torch_dtype=torch.float64),
    GGMLType.I8: dict(torch_dtype=torch.int8),
    GGMLType.I16: dict(torch_dtype=torch.int16),
    GGMLType.I32: dict(torch_dtype=torch.int32),
    GGMLType.I64: dict(torch_dtype=torch.int64),
    # The *_ALT names are IntEnum aliases of I16/I32/I64/F64, so they resolve
    # to the entries above; listing them separately here would overwrite those
    # rows with the wrong dtypes.
    # quantised (supported)
    GGMLType.Q8_0: dict(block=32, size=34, dequant=_dequant_q8_0),
    GGMLType.Q4_0: dict(block=32, size=18, dequant=_dequant_q4_0),
    GGMLType.Q4_1: dict(block=32, size=20, dequant=_dequant_q4_1),
    GGMLType.Q8_1: dict(block=32, size=36, dequant=_dequant_q8_1),
    GGMLType.Q4_K: dict(block=256, size=148, dequant=_dequant_q4_k),
    GGMLType.Q6_K: dict(block=256, size=210, dequant=_dequant_q6_k),
    # quantised (not yet ported)
    GGMLType.Q2_K: dict(supported=False),
    GGMLType.Q3_K: dict(supported=False),
    GGMLType.Q5_K: dict(supported=False),
    GGMLType.Q8_K: dict(supported=False),
}
# ────────────────────────────
# meta helpers
# ────────────────────────────
def _read_string(buf: BinaryIO, file_size: int) -> str:
    length = _UINT64.unpack(buf.read(8))[0]
    if buf.tell() + length > file_size:
        raise ValueError("GGUF file truncated while reading string")
    return buf.read(length).decode("utf-8")

def _read_meta_value(buf: BinaryIO, vt: GGUFMetaValueType, file_size: int) -> Any:
    s = _META_VALUE_STRUCTS.get(vt)
    if s:
        return s.unpack(buf.read(s.size))[0]
    if vt == GGUFMetaValueType.BOOL:
        return struct.unpack("<?", buf.read(1))[0]
    if vt == GGUFMetaValueType.STRING:
        return _read_string(buf, file_size)
    if vt == GGUFMetaValueType.ARRAY:
        item_type = GGUFMetaValueType(_UINT32.unpack(buf.read(4))[0])
        count = _UINT64.unpack(buf.read(8))[0]
        return [_read_meta_value(buf, item_type, file_size) for _ in range(count)]
    raise ValueError(f"unhandled meta type {vt}")

# ────────────────────────────
# tensor-descriptor
# ────────────────────────────
class _TensorInfo:
    __slots__ = ("name", "shape", "dtype", "offset", "n_bytes")

    def __init__(self, name: str, shape: List[int], dtype: GGMLType, offset: int):
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.offset = offset
        self.n_bytes = 0
# ────────────────────────────
# public loader
# ────────────────────────────
class GGUFLoader:
    def __init__(self, path: str | pathlib.Path, target_dtype: torch.dtype = torch.float16):
        self._path = pathlib.Path(path)
        self._target_dtype = target_dtype

    def load(self) -> Tuple[Dict[str, Any], Dict[str, torch.Tensor]]:
        if not self._path.is_file():
            raise FileNotFoundError(self._path)
        meta, tensors, tensor_data_buffers = self._read_structure()
        weights: Dict[str, torch.Tensor] = {}
        tensor_iterator = zip(tensors, tensor_data_buffers)
        if tqdm:
            tensor_iterator = tqdm.tqdm(
                tensor_iterator,
                total=len(tensors),
                desc="Loading tensors",
                unit="tensors",
            )
        for t, raw_data in tensor_iterator:
            info = _TYPE_INFO.get(t.dtype)
            if info is None:
                raise ValueError(f"{t.name}: unknown GGMLType {t.dtype}")
            if not info.get("supported", True):
                raise NotImplementedError(f"{t.dtype.name} not yet supported")
            numel = int(np.prod(t.shape)) if t.shape else 1
            if "torch_dtype" in info:
                buf = bytearray(raw_data)
                tensor = torch.frombuffer(buf, dtype=info["torch_dtype"]).reshape(*t.shape).clone()
            else:
                bs, fn = info["block"], info["dequant"]
                out = np.empty(numel, dtype=np.float32)
                fn(memoryview(raw_data), out, numel // bs)
                tensor = torch.from_numpy(out).reshape(*t.shape)
            weights[t.name] = tensor.to(self._target_dtype)
        return meta, weights
    def _read_structure(self) -> Tuple[Dict[str, Any], List[_TensorInfo], List[bytes]]:
        tensor_data_buffers = []
        with self._path.open("rb") as fh, mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            file_size = len(mm)
            bf = io.BytesIO(mm)
            magic, version, n_tensors, n_meta = _HDR.unpack(bf.read(_HDR.size))
            if magic != _MAGIC:
                raise ValueError("not a GGUF file")
            if version > 3:
                print(f"warning: GGUF v{version} – parser tested up to v3")
            meta: Dict[str, Any] = {}
            for _ in range(n_meta):
                key = _read_string(bf, file_size)
                vt = GGUFMetaValueType(_UINT32.unpack(bf.read(4))[0])
                meta[key] = _read_meta_value(bf, vt, file_size)
            alignment = int(meta.get("general.alignment", 32))
            tensors: List[_TensorInfo] = []
            for _ in range(n_tensors):
                name = _read_string(bf, file_size)
                ndims = _UINT32.unpack(bf.read(4))[0]
                dims = [_UINT64.unpack(bf.read(8))[0] for _ in range(ndims)]
                shape = list(reversed(dims))
                dtype = GGMLType(_UINT32.unpack(bf.read(4))[0])
                offset = _UINT64.unpack(bf.read(8))[0]
                tensors.append(_TensorInfo(name, shape, dtype, offset))
            data_start_offset = bf.tell()
            padding = -data_start_offset % alignment
            data_start_offset += padding
            for t in tensors:
                t.offset += data_start_offset
                numel = int(np.prod(t.shape)) if t.shape else 1
                info = _TYPE_INFO.get(t.dtype)
                if info is None:
                    raise ValueError(f"{t.name}: unknown GGMLType {t.dtype}")
                if "torch_dtype" in info:
                    elt = torch.tensor([], dtype=info["torch_dtype"])
                    t.n_bytes = numel * elt.element_size()
                elif "block" in info:
                    bs, blk_bytes = info["block"], info["size"]
                    if numel % bs:
                        raise ValueError(f"{t.name}: {numel} is not a multiple of block size {bs}")
                    t.n_bytes = (numel // bs) * blk_bytes
                else:
                    if not info.get("supported", True):
                        t.n_bytes = 0
                    else:
                        raise ValueError(f"Tensor {t.name} has quant type {t.dtype.name} with no size info.")
                if t.offset % alignment:
                    raise ValueError(f"{t.name} offset {t.offset} not aligned ({alignment})")
                if t.n_bytes > 0:
                    tensor_data_buffers.append(mm[t.offset : t.offset + t.n_bytes])
                elif not info.get("supported", True):
                    tensor_data_buffers.append(b'')
        return meta, tensors, tensor_data_buffers
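# A minimal usage sketch for the loader (illustrative only; "model.gguf" is a
# placeholder path, and the __main__ block at the bottom of this file is the
# full, runnable example):
#
#   loader = GGUFLoader("model.gguf", target_dtype=torch.float16)
#   meta, weights = loader.load()
#   print(meta.get("general.architecture"), len(weights), "tensors")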
# ───────────────────────────────────────────────────
# ─── INFERENCE EXAMPLE AND PYTORCH MODEL BELOW ───
# ───────────────────────────────────────────────────
class ModelArgs:
    def __init__(self, meta: Dict[str, Any], target_dtype: torch.dtype):
        self.dtype = target_dtype
        self.dim = meta["llama.embedding_length"]
        self.n_layers = meta["llama.block_count"]
        self.n_heads = meta["llama.attention.head_count"]
        self.n_kv_heads = meta["llama.attention.head_count_kv"]
        self.vocab_size = meta["llama.vocab_size"]
        self.hidden_dim = meta["llama.feed_forward_length"]
        self.rope_dim = meta.get("llama.rope.dimension_count", self.dim // self.n_heads)
        self.norm_eps = meta["llama.attention.layer_norm_rms_epsilon"]
        self.max_seq_len = meta.get("llama.context_length", 2048)
        self.rope_freq_base = meta.get("llama.rope.freq_base", 10000.0)
        # Standard GGUF metadata uses the tokenizer.ggml.* prefix for these ids;
        # fall back to the bare names and then to fixed defaults.
        self.bos_token_id = meta.get("tokenizer.ggml.bos_token_id", meta.get("tokenizer.bos_token_id", 1))
        self.eos_token_id = meta.get("tokenizer.ggml.eos_token_id", meta.get("tokenizer.eos_token_id", 2))
        self.pad_token_id = meta.get("tokenizer.ggml.padding_token_id", meta.get("tokenizer.pad_token_id", -1))
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
class RoPE:
    def __init__(self, dim: int, max_seq_len: int, base: float, dtype: torch.dtype):
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self.dtype = dtype
        self._precompute_freqs_cis()

    def _precompute_freqs_cis(self):
        freqs = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        t = torch.arange(self.max_seq_len, device=freqs.device)
        freqs_cis = torch.polar(torch.ones_like(t[:, None] * freqs[None, :]), t[:, None] * freqs[None, :])
        self.freqs_cis = freqs_cis

    def __call__(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int):
        seq_len = xq.shape[1]
        freqs = self.freqs_cis[start_pos : start_pos + seq_len].to(xq.device)
        # Perform the rotation in complex32 (backed by float16) so the complex
        # multiply matches the half-precision activations; the precision impact
        # is negligible for inference.
        freqs_c32 = freqs.to(torch.complex32)
        freqs_c32 = freqs_c32.view(1, seq_len, 1, -1)
        # View consecutive pairs in the last dimension of xq/xk as complex
        # numbers (real, imag); the reshape pairs up the elements.
        xq_c32 = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
        xk_c32 = torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2))
        # Apply the rotation and flatten back to the original layout.
        xq_rotated = torch.view_as_real(xq_c32 * freqs_c32).flatten(3)
        xk_rotated = torch.view_as_real(xk_c32 * freqs_c32).flatten(3)
        # Cast back to the original dtype to avoid type-promotion surprises downstream.
        return xq_rotated.to(xq.dtype), xk_rotated.to(xk.dtype)
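# Worked example of the pairing convention assumed above (illustrative values):
# a head-dim pair [x0, x1] is treated as the complex number x0 + i*x1 and
# rotated by e^{i*theta} at its position, i.e.
#   [x0, x1] -> [x0*cos(theta) - x1*sin(theta), x0*sin(theta) + x1*cos(theta)]
# which is exactly what the complex multiply by freqs_cis computes.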
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_kv_heads
        self.head_dim = args.dim // self.n_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        # Register the KV-cache tensors as buffers so they move to the correct
        # device together with the model.
        self.register_buffer('cache_k', torch.zeros((1, args.max_seq_len, self.n_kv_heads, self.head_dim), dtype=args.dtype))
        self.register_buffer('cache_v', torch.zeros((1, args.max_seq_len, self.n_kv_heads, self.head_dim), dtype=args.dtype))

    def forward(self, x: torch.Tensor, start_pos: int, rope: RoPE):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xq, xk = rope(xq, xk, start_pos)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        # Repeat the KV heads to match the number of query heads (grouped-query attention).
        keys = keys[:, :, :, None, :].expand(bsz, -1, self.n_kv_heads, self.n_rep, self.head_dim).reshape(bsz, -1, self.n_heads, self.head_dim)
        values = values[:, :, :, None, :].expand(bsz, -1, self.n_kv_heads, self.n_rep, self.head_dim).reshape(bsz, -1, self.n_heads, self.head_dim)
        # Ensure query, key, and value tensors share the same dtype.
        tgt_dtype = keys.dtype  # usually torch.float16
        if xq.dtype != tgt_dtype:
            xq = xq.to(tgt_dtype)
        if values.dtype != tgt_dtype:
            values = values.to(tgt_dtype)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        is_causal = seqlen > 1
        output = torch.nn.functional.scaled_dot_product_attention(xq, keys, values, is_causal=is_causal)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        # Match the output projection's weight dtype before the final matmul.
        return self.wo(output.to(self.wo.weight.dtype))
class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU feed-forward: silu(w1(x)) * w3(x), projected back down by w2.
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, rope: RoPE):
        h = x + self.attention(self.attention_norm(x), start_pos, rope)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        self.rope = RoPE(args.rope_dim, args.max_seq_len, args.rope_freq_base, args.dtype)

    def forward(self, tokens: torch.Tensor, start_pos: int):
        h = self.tok_embeddings(tokens)
        for layer in self.layers:
            h = layer(h, start_pos, self.rope)
        h = self.norm(h)
        return self.output(h).float()

    def reset_kv_cache(self):
        for layer in self.layers:
            layer.attention.cache_k.zero_()
            layer.attention.cache_v.zero_()
class SimpleTokenizer:
    def __init__(self, meta: Dict[str, Any]):
        # Store every token as bytes so decode can join them directly.
        self.tokens = [token.encode('utf-8') if isinstance(token, str) else token for token in meta['tokenizer.ggml.tokens']]
        self.scores = meta.get('tokenizer.ggml.scores')
        self.tok_map = {tok: i for i, tok in enumerate(self.tokens)}
        # Prefer the standard tokenizer.ggml.* keys, falling back to the bare names and defaults.
        self.bos_token_id = int(meta.get('tokenizer.ggml.bos_token_id', meta.get('tokenizer.bos_token_id', 1)))
        self.eos_token_id = int(meta.get('tokenizer.ggml.eos_token_id', meta.get('tokenizer.eos_token_id', 2)))
        self.unk_token_id = int(meta.get('tokenizer.ggml.unknown_token_id', meta.get('tokenizer.unk_token_id', 0)))

    def encode(self, text: str, bos: bool = True) -> List[int]:
        token_ids = []
        if bos:
            token_ids.append(self.bos_token_id)
        # Byte-level fallback: map each UTF-8 byte to the matching single-byte token.
        for char_byte in text.encode('utf-8'):
            char = bytes([char_byte])
            token_ids.append(self.tok_map.get(char, self.unk_token_id))
        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        filtered_ids = [tid for tid in token_ids if tid not in (self.bos_token_id, self.eos_token_id)]
        # self.tokens holds bytes objects, so the pieces can be joined and decoded directly.
        byte_tokens = b"".join(self.tokens[tid] for tid in filtered_ids if tid < len(self.tokens))
        return byte_tokens.decode('utf-8', errors='replace')
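# Note on the encoder above: it is a byte-level fallback, not a real BPE/SPM
# tokenizer, so multi-byte merges from the vocabulary are never produced. As an
# illustration (the ids shown are hypothetical), "hi" encodes as
#   [bos_token_id, id_of(b"h"), id_of(b"i")]
# rather than as a single merged token.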
def map_gguf_weights_to_pytorch(
    gguf_weights: Dict[str, torch.Tensor], args: ModelArgs
) -> Dict[str, torch.Tensor]:
    pytorch_weights = {}
    for name, tensor in gguf_weights.items():
        if name == "token_embd.weight":
            pytorch_weights["tok_embeddings.weight"] = tensor
        elif name == "output_norm.weight":
            pytorch_weights["norm.weight"] = tensor
        elif name == "output.weight" or name == "lm_head.weight":
            pytorch_weights["output.weight"] = tensor
        elif name.startswith("blk."):
            parts = name.split('.')
            layer_num = parts[1]
            layer_type_full = ".".join(parts[2:])
            if layer_type_full == "attn_norm.weight":
                new_name = f"layers.{layer_num}.attention_norm.weight"
            elif layer_type_full == "ffn_norm.weight":
                new_name = f"layers.{layer_num}.ffn_norm.weight"
            elif layer_type_full == "attn_q.weight":
                new_name = f"layers.{layer_num}.attention.wq.weight"
            elif layer_type_full == "attn_k.weight":
                new_name = f"layers.{layer_num}.attention.wk.weight"
            elif layer_type_full == "attn_v.weight":
                new_name = f"layers.{layer_num}.attention.wv.weight"
            elif layer_type_full == "attn_output.weight":
                new_name = f"layers.{layer_num}.attention.wo.weight"
            elif layer_type_full == "ffn_gate.weight":
                new_name = f"layers.{layer_num}.feed_forward.w1.weight"
            elif layer_type_full == "ffn_down.weight":
                new_name = f"layers.{layer_num}.feed_forward.w2.weight"
            elif layer_type_full == "ffn_up.weight":
                new_name = f"layers.{layer_num}.feed_forward.w3.weight"
            else:
                print(f"Warning: Unhandled GGUF block tensor: {name}")
                continue
            pytorch_weights[new_name] = tensor
        else:
            if name != "rope_freqs.weight":
                print(f"Warning: Unhandled GGUF tensor (will be ignored): {name}")
    if "output.weight" not in pytorch_weights and "tok_embeddings.weight" in pytorch_weights:
        print("Info: Tying output.weight to tok_embeddings.weight.")
        pytorch_weights["output.weight"] = pytorch_weights["tok_embeddings.weight"]
    expected_keys = Transformer(args).state_dict().keys()
    missing_keys = [k for k in expected_keys if k not in pytorch_weights]
    if missing_keys:
        print(f"Warning: The following keys were expected but not found: {missing_keys}")
    return pytorch_weights
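# Example of the renaming performed above, for layer 0 (names taken directly
# from the mapping itself):
#   "blk.0.attn_q.weight"   -> "layers.0.attention.wq.weight"
#   "blk.0.ffn_gate.weight" -> "layers.0.feed_forward.w1.weight"
#   "token_embd.weight"     -> "tok_embeddings.weight"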
@torch.no_grad()
def generate(
    model: Transformer,
    tokenizer: SimpleTokenizer,
    prompt: str,
    max_tokens: int = 100
):
    model.eval()
    model.reset_kv_cache()
    token_ids = tokenizer.encode(prompt, bos=True)
    if not token_ids:
        print("\nPrompt is empty.")
        return
    # Determine the model's device and move the input tensor to it.
    device = next(model.parameters()).device
    prompt_tokens = torch.tensor([token_ids], dtype=torch.long, device=device)
    start_time = time.time()
    logits = model(prompt_tokens, 0)
    prefill_time = time.time() - start_time
    generated_ids = []
    # Subsequent tensors will be on the correct device because they come from the model's output.
    current_token = torch.argmax(logits[:, -1, :], dim=-1)
    start_gen_time = time.time()
    for i in range(max_tokens):
        if current_token.item() == tokenizer.eos_token_id:
            break
        generated_ids.append(current_token.item())
        print(tokenizer.decode([current_token.item()]), end="", flush=True)
        logits = model(current_token.view(1, 1), len(token_ids) + i)
        current_token = torch.argmax(logits[:, -1, :], dim=-1)
    gen_time = time.time() - start_gen_time
    print("\n\n--- Stats ---")
    prefill_tok_per_sec = len(token_ids) / prefill_time if prefill_time > 0 else float('inf')
    gen_tok_per_sec = len(generated_ids) / gen_time if gen_time > 0 else float('inf')
    print(f"Prefill time: {prefill_time:.2f}s ({prefill_tok_per_sec:.2f} t/s)")
    print(f"Generation time: {gen_time:.2f}s ({gen_tok_per_sec:.2f} t/s)")
    print(f"Total tokens generated: {len(generated_ids)}")
if __name__ == "__main__":
    import sys

    if not tqdm:
        print("Note: 'tqdm' not found. To see a loading progress bar, run: pip install tqdm", file=sys.stderr)

    # --- Configuration ---
    # <<< PLEASE REPLACE THIS WITH THE PATH TO YOUR GGUF FILE >>>
    GGUF_PATH = r'D:\ia\levi_chat\llm\Llama-3.2-1B-Instruct-Q4_K_M.gguf'
    TARGET_DTYPE = torch.float16
    PROMPT = "The capital of France is"
    MAX_TOKENS_TO_GENERATE = 100

    try:
        # --- 1. Load GGUF file ---
        print(f"Loading GGUF file: {GGUF_PATH}")
        start_load_time = time.time()
        loader = GGUFLoader(GGUF_PATH, target_dtype=TARGET_DTYPE)
        meta, gguf_weights = loader.load()
        print(f"GGUF loaded in {time.time() - start_load_time:.2f}s")

        if 'general.architecture' not in meta or meta['general.architecture'] != 'llama':
            print(f"Error: This script is designed for the 'llama' architecture, "
                  f"but the model is '{meta.get('general.architecture', 'unknown')}'.", file=sys.stderr)
            sys.exit(1)

        # --- 2. Set up model and tokenizer ---
        args = ModelArgs(meta, TARGET_DTYPE)
        model = Transformer(args)
        tokenizer = SimpleTokenizer(meta)

        # Move the model to the GPU if one is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        # --- 3. Map weights and load them into the model ---
        print("\nMapping and loading weights into the PyTorch model...")
        pytorch_weights = map_gguf_weights_to_pytorch(gguf_weights, args)
        model.load_state_dict(pytorch_weights, strict=False)
        print("Weights loaded successfully.")

        # --- Print model info ---
        print("\n--- Model Info ---")
        print(f"- Architecture: {meta.get('general.architecture')}")
        print(f"- Layers: {args.n_layers}, Heads: {args.n_heads}, Dim: {args.dim}")
        print(f"- Target DType: {TARGET_DTYPE}")
        print(f"- Vocab Size: {args.vocab_size}")
        print(f"- Inference Device: {device}")

        # --- 4. Run inference ---
        print("\n--- Inference ---")
        print(f"Prompt: {PROMPT}")
        print("Response:", end=" ", flush=True)
        generate(model, tokenizer, PROMPT, max_tokens=MAX_TOKENS_TO_GENERATE)
        print("\n\nSuccessfully loaded and ran GGUF file.")

    except (FileNotFoundError, NotImplementedError, ValueError, KeyError) as e:
        print(f"\nError: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        import traceback
        print(f"\nAn unexpected error occurred: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)