Speed up quant model loading and inference ...

... based on 3 evidences:
1. torch.Tensor.view on one big tensor is slightly faster than calling torch.Tensor.to on multiple small tensors.
2. but torch.Tensor.to with dtype change is significantly slower than torch.Tensor.view
3. “baking” model on GPU is significantly faster than computing on CPU when model load.

mainly influence inference of Q8_0, Q4_0/1/K and loading of all quants
This commit is contained in:
layerdiffusion
2024-08-30 00:49:05 -07:00
parent 3d62fa9598
commit 4c9380c46a
7 changed files with 126 additions and 181 deletions

View File

@@ -13,7 +13,7 @@ from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpac
import numpy as np
quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=1)
quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=-1)
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
@@ -90,8 +90,6 @@ class __Quant(ABC):
grid_map: tuple[int | float, ...] = ()
grid_hex: bytes | None = None
computation_dtype: torch.dtype = torch.bfloat16
def __init__(self):
return TypeError("Quant conversion classes can't have instances")
@@ -144,29 +142,35 @@ class __Quant(ABC):
return blocks.reshape(original_shape)
@classmethod
def bake_layer(cls, layer, weight, computation_dtype):
data = weight.data
cls.computation_dtype = computation_dtype
def bake(cls, parameter):
if parameter.baked:
return
data = parameter.data
cls.block_size, cls.type_size = GGML_QUANT_SIZES[cls.qtype]
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
n_blocks = rows.numel() // cls.type_size
blocks = rows.reshape((n_blocks, cls.type_size))
weight.data = blocks
cls.bake_layer_weight(layer, weight)
parameter.data = blocks.contiguous()
cls.bake_inner(parameter)
parameter.baked = True
return
@classmethod
def bake_layer_weight(cls, layer, weight):
def bake_inner(cls, parameter):
pass
@classmethod
def dequantize_pytorch(cls, x, original_shape) -> torch.Tensor:
blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x.parent)
return blocks.reshape(original_shape)
def dequantize_pytorch(cls, x):
if not x.baked:
raise ValueError('GGUF Tensor is not baked!')
blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x)
return blocks.view(x.shape)
@classmethod
@abstractmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
raise NotImplementedError
@classmethod
@@ -303,22 +307,18 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
return (d * qs.astype(np.float32))
@classmethod
def bake_layer_weight(cls, layer, weight):
blocks = weight.data
def bake_inner(cls, parameter):
blocks = parameter.data
d, x = quick_split(blocks, [2])
d = d.view(torch.float16).to(cls.computation_dtype)
x = change_4bits_order(x)
weight.data = x
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
x = change_4bits_order(x).view(torch.uint8)
parameter.data = torch.cat([d, x], dim=-1).contiguous()
return
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
d, qs = parent.quant_state_0, blocks
if d.device != qs.device:
d = d.to(device=qs.device)
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
d, qs = quick_split(blocks, [2])
d = d.view(parameter.computation_dtype)
qs = quick_unpack_4bits(qs)
return d * qs
@@ -381,30 +381,23 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
return (d * qs) + m
@classmethod
def bake_layer_weight(cls, layer, weight):
blocks = weight.data
def bake_inner(cls, parameter):
blocks = parameter.data
d, m, qs = quick_split(blocks, [2, 2])
d = d.view(torch.float16).to(cls.computation_dtype)
m = m.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
m = m.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
qs = change_4bits_order(qs).view(torch.uint8)
qs = change_4bits_order(qs)
parameter.data = torch.cat([d, m, qs], dim=-1).contiguous()
weight.data = qs
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False)
return
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks
if d.device != qs.device:
d = d.to(device=qs.device)
if m.device != qs.device:
m = m.to(device=qs.device)
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
d, m, qs = quick_split(blocks, [2, 2])
d = d.view(parameter.computation_dtype)
m = m.view(parameter.computation_dtype)
qs = quick_unpack_4bits_u(qs)
return (d * qs) + m
@@ -452,7 +445,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
return (d * qs.astype(np.float32))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
def to_uint32(x):
# pytorch uint32 by City96 - Apache-2.0
x = x.view(torch.uint8).to(torch.int32)
@@ -461,7 +454,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
n_blocks = blocks.shape[0]
d, qh, qs = quick_split(blocks, [2, 4])
d = d.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
qh = to_uint32(qh)
qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -555,7 +548,7 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
return (d * qs) + m
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
def to_uint32(x):
# pytorch uint32 by City96 - Apache-2.0
x = x.view(torch.uint8).to(torch.int32)
@@ -564,8 +557,8 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
n_blocks = blocks.shape[0]
d, m, qh, qs = quick_split(blocks, [2, 2, 4])
d = d.view(torch.float16).to(cls.computation_dtype)
m = m.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
m = m.view(torch.float16).to(parameter.computation_dtype)
qh = to_uint32(qh)
qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -603,23 +596,18 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
return (x * d)
@classmethod
def bake_layer_weight(cls, layer, weight):
blocks = weight.data
def bake_inner(cls, parameter):
blocks = parameter.data
d, x = quick_split(blocks, [2])
x = x.view(torch.int8)
d = d.view(torch.float16).to(cls.computation_dtype)
weight.data = x
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.int8)
parameter.data = torch.cat([d, x], dim=-1).contiguous()
return
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
x = blocks
d = parent.quant_state_0
if d.device != x.device:
d = d.to(device=x.device)
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
d, x = quick_split(blocks, [2])
d = d.view(parameter.computation_dtype)
return x * d
@classmethod
@@ -660,12 +648,12 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
return qs.reshape((n_blocks, -1))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
n_blocks = blocks.shape[0]
scales, qs, d, dmin = quick_split(blocks, [QK_K // 16, QK_K // 4, 2])
d = d.view(torch.float16).to(cls.computation_dtype)
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
# (n_blocks, 16, 1)
dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
@@ -720,11 +708,11 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
return (dl * q).reshape((n_blocks, QK_K))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
n_blocks = blocks.shape[0]
hmask, qs, scales, d = quick_split(blocks, [QK_K // 8, QK_K // 4, 12])
d = d.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
lscales, hscales = scales[:, :8], scales[:, 8:]
lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
lscales = lscales.reshape((n_blocks, 16))
@@ -801,42 +789,39 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
return (d * qs - dm).reshape((n_blocks, QK_K))
@classmethod
def bake_layer_weight(cls, layer, weight): # Only compute one time when model load
def bake_inner(cls, parameter): # Only compute one time when model load
# Copyright Forge 2024
blocks = weight.data
K_SCALE_SIZE = 12
blocks = parameter.data
n_blocks = blocks.shape[0]
d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE])
d = d.view(torch.float16).to(cls.computation_dtype)
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
d, dmin, scales, qs = quick_split(blocks, [2, 2, cls.K_SCALE_SIZE])
d = d.view(torch.float16).to(parameter.computation_dtype)
dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
sc, m = Q4_K.get_scale_min_pytorch(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype)
dm = (dmin * m).reshape((n_blocks, -1, 1)).to(parameter.computation_dtype)
qs = qs.reshape((n_blocks, -1, 1, 32))
qs = change_4bits_order(qs)
weight.data = qs
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False)
d = d.view(torch.uint8).reshape((n_blocks, -1))
dm = dm.view(torch.uint8).reshape((n_blocks, -1))
qs = qs.view(torch.uint8)
parameter.data = torch.cat([d, dm, qs], dim=-1).contiguous()
return
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# Compute in each diffusion iteration
n_blocks = blocks.shape[0]
d, dm, qs = parent.quant_state_0, parent.quant_state_1, blocks
d, dm, qs = quick_split(blocks, [16, 16])
d = d.view(parameter.computation_dtype).view((n_blocks, -1, 1))
dm = dm.view(parameter.computation_dtype).view((n_blocks, -1, 1))
qs = quick_unpack_4bits_u(qs).view((n_blocks, -1, 32))
if d.device != qs.device:
d = d.to(device=qs.device)
if dm.device != qs.device:
dm = dm.to(device=qs.device)
qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32))
return (d * qs - dm).reshape((n_blocks, QK_K))
@@ -867,14 +852,14 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
return (d * q - dm).reshape((n_blocks, QK_K))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
QK_K = 256
K_SCALE_SIZE = 12
n_blocks = blocks.shape[0]
d, dmin, scales, qh, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE, QK_K // 8])
d = d.view(torch.float16).to(cls.computation_dtype)
dmin = dmin.view(torch.float16).to(cls.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
sc, m = Q4_K.get_scale_min_pytorch(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
@@ -909,12 +894,12 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
return (d * q).reshape((n_blocks, QK_K))
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
# Written by ChatGPT
n_blocks = blocks.shape[0]
ql, qh, scales, d, = quick_split(blocks, [QK_K // 2, QK_K // 4, QK_K // 16])
scales = scales.view(torch.int8).to(cls.computation_dtype)
d = d.view(torch.float16).to(cls.computation_dtype)
scales = scales.view(torch.int8).to(parameter.computation_dtype)
d = d.view(torch.float16).to(parameter.computation_dtype)
d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))