reimplement q8/q5/q4 and review and match official gguf

layerdiffusion
2024-08-15 02:41:15 -07:00
parent 447f261154
commit 2690b654fd
3 changed files with 70 additions and 74 deletions
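Since the PyTorch dequantizers now mirror the official gguf quant classes, each of the three types can be checked against the upstream numpy path. A minimal verification sketch (not part of this commit; it assumes the vendored package keeps gguf-py's quantize/dequantize helpers in gguf.quants and re-exports Q4_0/Q5_0/Q8_0 at the top level, as used below):

import gguf
import numpy as np
import torch
from gguf.quants import quantize, dequantize

weight = np.random.randn(64, 256).astype(np.float32)
for qtype, cls in [(gguf.GGMLQuantizationType.Q4_0, gguf.Q4_0),
                   (gguf.GGMLQuantizationType.Q5_0, gguf.Q5_0),
                   (gguf.GGMLQuantizationType.Q8_0, gguf.Q8_0)]:
    raw = quantize(weight, qtype)                  # packed block bytes (numpy uint8)
    ref = dequantize(raw, qtype)                   # official float32 reference
    out = cls.dequantize_pytorch(torch.from_numpy(raw), weight.shape)
    assert np.allclose(out.to(torch.float32).numpy(), ref, atol=1e-2)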


@@ -2,6 +2,13 @@ import gguf
import torch
quants_mapping = {
    gguf.GGMLQuantizationType.Q4_0: gguf.Q4_0,
    gguf.GGMLQuantizationType.Q5_0: gguf.Q5_0,
    gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
}
def functional_linear_gguf(x, weight, bias=None):
    target_dtype = x.dtype
    weight = dequantize_tensor(weight, target_dtype)
@@ -20,80 +27,9 @@ def dequantize_tensor(tensor, target_dtype=torch.float16):
    if gguf_type in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16]:
        return data.to(target_dtype)
    if gguf_type not in dequantize_functions:
    if gguf_type not in quants_mapping:
        raise NotImplementedError(f'Quant type {gguf_type} not implemented!')
    return dequantize(data, gguf_type, gguf_real_shape).to(target_dtype)
    quant_cls = quants_mapping.get(gguf_type)
def dequantize(data, qtype, oshape):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]
    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)
    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size)
    return blocks.reshape(oshape)
def dequantize_blocks_Q8_0(blocks, block_size, type_size):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    d = blocks[:, :2].view(torch.float16)
    x = blocks[:, 2:].view(torch.int8).to(torch.float16)
    return x * d
def dequantize_blocks_Q5_0(blocks, block_size, type_size):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    def to_uint32(x):
        x = x.view(torch.uint8).to(torch.int32)
        return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
    n_blocks = blocks.shape[0]
    d = blocks[:, :2]
    qh = blocks[:, 2:6]
    qs = blocks[:, 6:]
    d = d.view(torch.float16).to(torch.float32)
    qh = to_uint32(qh)
    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)
    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return d * qs
def dequantize_blocks_Q4_0(blocks, block_size, type_size):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    n_blocks = blocks.shape[0]
    d = blocks[:, :2].view(torch.float16)
    qs = blocks[:, 2:]
    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return d * qs
dequantize_functions = {
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
}
    return quant_cls.dequantize_pytorch(data, gguf_real_shape).to(target_dtype)

packages_3rdparty/gguf/README.md (vendored, new file)

@@ -0,0 +1,2 @@
This is Forge's implementation of GGUF; the difference is that it supports PyTorch quant/dequant.
The code is based on llama.cpp's GGUF package.
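A minimal usage sketch of the PyTorch dequant path (illustration only, not part of the README; it assumes the packed bytes of a Q4_0 tensor have already been loaded into a uint8 torch tensor, e.g. by a GGUF reader):

import gguf
import torch

# a (64, 128) Q4_0 weight: 128 / 32 = 4 blocks of 18 bytes per row
raw = torch.randint(0, 256, (64, 4 * 18), dtype=torch.uint8)  # placeholder bytes, values are meaningless
weight = gguf.Q4_0.dequantize_pytorch(raw, (64, 128))          # float16 tensor of shape (64, 128)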


@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Any, Callable, Sequence
from math import log2, ceil
import torch
from numpy.typing import DTypeLike
from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K
@@ -123,6 +124,21 @@ class __Quant(ABC):
        grid = np.take_along_axis(grid_map, grid, axis=-1)
        cls.grid = grid.reshape((1, 1, *cls.grid_shape))
    @classmethod
    def dequantize_pytorch(cls, data: torch.Tensor, original_shape: Sequence[int]) -> torch.Tensor:
        # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
        block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
        rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
        n_blocks = rows.numel() // type_size
        blocks = rows.reshape((n_blocks, type_size))
        blocks = cls.dequantize_blocks_pytorch(blocks, block_size, type_size)
        return blocks.reshape(original_shape)
    @classmethod
    @abstractmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
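The shared dequantize_pytorch only does byte bookkeeping; the per-type work happens in dequantize_blocks_pytorch. As a worked example of the reshape arithmetic (assuming GGML_QUANT_SIZES matches upstream, i.e. Q4_0 packs 32 weights into 18 bytes):

block_size, type_size = 32, 18                     # GGML_QUANT_SIZES[Q4_0]
rows, cols = 4096, 4096                            # a typical weight matrix
raw_bytes = rows * cols // block_size * type_size  # 9,437,184 packed bytes
n_blocks = raw_bytes // type_size                  # 524,288 blocks of 18 bytes
assert n_blocks * block_size == rows * cols        # each block expands back to 32 weights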
@@ -251,6 +267,17 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
        return (d * qs.astype(np.float32))
    @classmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        n_blocks = blocks.shape[0]
        d = blocks[:, :2].view(torch.float16)
        qs = blocks[:, 2:]
        qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
        qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
        return d * qs
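For reference, a Q4_0 block is 18 bytes covering 32 weights: a float16 scale d followed by 16 bytes of packed 4-bit values (the low nibbles come first in the unpacked order, then the high nibbles), and each weight is d * (nibble - 8). A hand-unpacked single block, for illustration only:

import torch

block = torch.arange(18, dtype=torch.uint8)            # one fake 18-byte Q4_0 block
d = block[:2].view(torch.float16)                      # bytes 0-1: fp16 scale
qs = block[2:]                                         # bytes 2-17: 32 packed 4-bit values
lo, hi = qs & 0x0F, qs >> 4
weights = d * (torch.cat([lo, hi]).to(torch.int8) - 8).to(torch.float16)
assert weights.shape == (32,)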
class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
    @classmethod
@@ -331,6 +358,31 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
        return (d * qs.astype(np.float32))
    @classmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        def to_uint32(x):
            # pytorch uint32 by City96 - Apache-2.0
            x = x.view(torch.uint8).to(torch.int32)
            return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
        n_blocks = blocks.shape[0]
        d = blocks[:, :2]
        qh = blocks[:, 2:6]
        qs = blocks[:, 6:]
        d = d.view(torch.float16).to(torch.float32)
        qh = to_uint32(qh)
        qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
        ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
        qh = (qh & 1).to(torch.uint8)
        ql = (ql & 0x0F).reshape(n_blocks, -1)
        qs = (ql | (qh << 4)).to(torch.int8) - 16
        return d * qs
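Likewise, a Q5_0 block is 22 bytes for 32 weights: a float16 scale d, 4 bytes (qh) holding one extra high bit per weight, and 16 bytes (ql) of packed low nibbles; each weight is d * (((high_bit << 4) | nibble) - 16). A hand-unpacked single block, for illustration only:

import torch

block = torch.arange(22, dtype=torch.uint8)            # one fake 22-byte Q5_0 block
d = block[:2].view(torch.float16).to(torch.float32)    # bytes 0-1: fp16 scale
qh, ql = block[2:6], block[6:]                         # 32 high bits, 32 packed low nibbles
bits = torch.tensor([(int(qh[i // 8]) >> (i % 8)) & 1 for i in range(32)], dtype=torch.uint8)
nibbles = torch.cat([ql & 0x0F, ql >> 4])              # same ordering as the bits above
weights = d * ((nibbles | (bits << 4)).to(torch.int8) - 16).to(torch.float32)
assert weights.shape == (32,)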
class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
    @classmethod
@@ -402,6 +454,12 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
        return (x * d)
    @classmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        d = blocks[:, :2].view(torch.float16)
        x = blocks[:, 2:].view(torch.int8).to(torch.float16)
        return x * d
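Q8_0 is the simplest layout: 34 bytes per block, a float16 scale d followed by 32 signed int8 values, each weight being d * q. For illustration only:

import torch

block = torch.arange(34, dtype=torch.uint8)            # one fake 34-byte Q8_0 block
d = block[:2].view(torch.float16)                      # bytes 0-1: fp16 scale
q = block[2:].view(torch.int8)                         # bytes 2-33: 32 signed values
weights = q.to(torch.float16) * d
assert weights.shape == (32,)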
class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
    @classmethod