Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
Speed up quantized model loading and inference ...

... based on three observations:

1. torch.Tensor.view on one big tensor is slightly faster than calling torch.Tensor.to on multiple small tensors.
2. However, torch.Tensor.to with a dtype change is significantly slower than torch.Tensor.view.
3. "Baking" the model on the GPU at load time is significantly faster than computing it on the CPU.

This mainly affects inference for Q8_0 and Q4_0/1/K, and loading for all quant types.
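For reference, a rough micro-benchmark sketch of the three observations. The shapes, dtypes, and iteration counts below are illustrative assumptions, not the measurements behind this commit:

import time
import torch

def bench(fn, n=10000):
    fn()  # warmup
    t0 = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - t0) / n * 1e6  # microseconds per call

# Fake Q8_0-style blocks: a 2-byte fp16 scale followed by 32 quantized bytes.
blocks = torch.randint(0, 256, (4096, 34), dtype=torch.uint8)
d_raw, qs = torch.split(blocks, [2, 32], dim=-1)

# Observation 2: .to() with a dtype change allocates and converts on every
# call, while .view() only reinterprets the same bytes.
t_to = bench(lambda: d_raw.view(torch.float16).to(torch.bfloat16))
t_view = bench(lambda: d_raw.view(torch.float16))

# Observation 1: one .view() on a single packed tensor vs. .to() on several
# small state tensors (as in the old quant_state_0 / quant_state_1 layout).
d_small, m_small = torch.randn(4096, 1), torch.randn(4096, 1)
t_many = bench(lambda: (d_small.to(torch.bfloat16), m_small.to(torch.bfloat16)))
packed = torch.cat([d_small, m_small], dim=-1).to(torch.bfloat16).view(torch.uint8)
t_one = bench(lambda: packed.view(torch.bfloat16))

# Observation 3: the one-time bake (dtype conversion over a whole tensor)
# runs much faster on GPU than on CPU at load time.
if torch.cuda.is_available():
    big = torch.randint(0, 256, (8192, 4096), dtype=torch.uint8)
    t_cpu = bench(lambda: big.view(torch.float16).to(torch.bfloat16), n=10)
    big_gpu = big.cuda()
    t_gpu = bench(lambda: (big_gpu.view(torch.float16).to(torch.bfloat16),
                           torch.cuda.synchronize())[1], n=10)
    print(f'bake on CPU: {t_cpu:.0f}us, bake on GPU: {t_gpu:.0f}us')

print(f'view: {t_view:.2f}us vs to(dtype): {t_to:.2f}us')
print(f'two small .to(): {t_many:.2f}us vs one packed .view(): {t_one:.2f}us')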
packages_3rdparty/gguf/quants.py (vendored, 163 changed lines)
@@ -13,7 +13,7 @@ from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpac
 import numpy as np
 
 
-quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=1)
+quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=-1)
 
 
 def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
@@ -90,8 +90,6 @@ class __Quant(ABC):
     grid_map: tuple[int | float, ...] = ()
     grid_hex: bytes | None = None
 
-    computation_dtype: torch.dtype = torch.bfloat16
-
     def __init__(self):
         return TypeError("Quant conversion classes can't have instances")
 
@@ -144,29 +142,35 @@ class __Quant(ABC):
         return blocks.reshape(original_shape)
 
     @classmethod
-    def bake_layer(cls, layer, weight, computation_dtype):
-        data = weight.data
-        cls.computation_dtype = computation_dtype
+    def bake(cls, parameter):
+        if parameter.baked:
+            return
+
+        data = parameter.data
         cls.block_size, cls.type_size = GGML_QUANT_SIZES[cls.qtype]
         rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
         n_blocks = rows.numel() // cls.type_size
         blocks = rows.reshape((n_blocks, cls.type_size))
-        weight.data = blocks
-        cls.bake_layer_weight(layer, weight)
+        parameter.data = blocks.contiguous()
+        cls.bake_inner(parameter)
+        parameter.baked = True
         return
 
     @classmethod
-    def bake_layer_weight(cls, layer, weight):
+    def bake_inner(cls, parameter):
        pass
 
     @classmethod
-    def dequantize_pytorch(cls, x, original_shape) -> torch.Tensor:
-        blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x.parent)
-        return blocks.reshape(original_shape)
+    def dequantize_pytorch(cls, x):
+        if not x.baked:
+            raise ValueError('GGUF Tensor is not baked!')
+
+        blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x)
+        return blocks.view(x.shape)
 
     @classmethod
     @abstractmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
         raise NotImplementedError
 
     @classmethod
@@ -303,22 +307,18 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         return (d * qs.astype(np.float32))
 
     @classmethod
-    def bake_layer_weight(cls, layer, weight):
-        blocks = weight.data
+    def bake_inner(cls, parameter):
+        blocks = parameter.data
         d, x = quick_split(blocks, [2])
-        d = d.view(torch.float16).to(cls.computation_dtype)
-        x = change_4bits_order(x)
-        weight.data = x
-        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
+        x = change_4bits_order(x).view(torch.uint8)
+        parameter.data = torch.cat([d, x], dim=-1).contiguous()
         return
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
-        d, qs = parent.quant_state_0, blocks
-
-        if d.device != qs.device:
-            d = d.to(device=qs.device)
-
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
+        d, qs = quick_split(blocks, [2])
+        d = d.view(parameter.computation_dtype)
         qs = quick_unpack_4bits(qs)
         return d * qs
 
@@ -381,30 +381,23 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
         return (d * qs) + m
 
     @classmethod
-    def bake_layer_weight(cls, layer, weight):
-        blocks = weight.data
+    def bake_inner(cls, parameter):
+        blocks = parameter.data
 
         d, m, qs = quick_split(blocks, [2, 2])
-        d = d.view(torch.float16).to(cls.computation_dtype)
-        m = m.view(torch.float16).to(cls.computation_dtype)
-
-        qs = change_4bits_order(qs)
-
-        weight.data = qs
-        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
-        layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False)
+        d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
+        m = m.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8)
+        qs = change_4bits_order(qs).view(torch.uint8)
+
+        parameter.data = torch.cat([d, m, qs], dim=-1).contiguous()
         return
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
-        d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks
-
-        if d.device != qs.device:
-            d = d.to(device=qs.device)
-
-        if m.device != qs.device:
-            m = m.to(device=qs.device)
-
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
+        d, m, qs = quick_split(blocks, [2, 2])
+        d = d.view(parameter.computation_dtype)
+        m = m.view(parameter.computation_dtype)
         qs = quick_unpack_4bits_u(qs)
         return (d * qs) + m
 
@@ -452,7 +445,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         return (d * qs.astype(np.float32))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
         def to_uint32(x):
             # pytorch uint32 by City96 - Apache-2.0
             x = x.view(torch.uint8).to(torch.int32)
@@ -461,7 +454,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         n_blocks = blocks.shape[0]
 
         d, qh, qs = quick_split(blocks, [2, 4])
-        d = d.view(torch.float16).to(cls.computation_dtype)
+        d = d.view(torch.float16).to(parameter.computation_dtype)
         qh = to_uint32(qh)
 
         qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -555,7 +548,7 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
         return (d * qs) + m
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
         def to_uint32(x):
             # pytorch uint32 by City96 - Apache-2.0
             x = x.view(torch.uint8).to(torch.int32)
@@ -564,8 +557,8 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
         n_blocks = blocks.shape[0]
 
         d, m, qh, qs = quick_split(blocks, [2, 2, 4])
-        d = d.view(torch.float16).to(cls.computation_dtype)
-        m = m.view(torch.float16).to(cls.computation_dtype)
+        d = d.view(torch.float16).to(parameter.computation_dtype)
+        m = m.view(torch.float16).to(parameter.computation_dtype)
         qh = to_uint32(qh)
 
         qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -603,23 +596,18 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         return (x * d)
 
     @classmethod
-    def bake_layer_weight(cls, layer, weight):
-        blocks = weight.data
+    def bake_inner(cls, parameter):
+        blocks = parameter.data
         d, x = quick_split(blocks, [2])
         x = x.view(torch.int8)
-        d = d.view(torch.float16).to(cls.computation_dtype)
-        weight.data = x
-        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.int8)
+        parameter.data = torch.cat([d, x], dim=-1).contiguous()
         return
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
-        x = blocks
-        d = parent.quant_state_0
-
-        if d.device != x.device:
-            d = d.to(device=x.device)
-
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
+        d, x = quick_split(blocks, [2])
+        d = d.view(parameter.computation_dtype)
         return x * d
 
     @classmethod
@@ -660,12 +648,12 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
         return qs.reshape((n_blocks, -1))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         n_blocks = blocks.shape[0]
         scales, qs, d, dmin = quick_split(blocks, [QK_K // 16, QK_K // 4, 2])
-        d = d.view(torch.float16).to(cls.computation_dtype)
-        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
+        d = d.view(torch.float16).to(parameter.computation_dtype)
+        dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
         # (n_blocks, 16, 1)
         dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
         ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
@@ -720,11 +708,11 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
         return (dl * q).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         n_blocks = blocks.shape[0]
         hmask, qs, scales, d = quick_split(blocks, [QK_K // 8, QK_K // 4, 12])
-        d = d.view(torch.float16).to(cls.computation_dtype)
+        d = d.view(torch.float16).to(parameter.computation_dtype)
         lscales, hscales = scales[:, :8], scales[:, 8:]
         lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
         lscales = lscales.reshape((n_blocks, 16))
@@ -801,42 +789,39 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
         return (d * qs - dm).reshape((n_blocks, QK_K))
 
     @classmethod
-    def bake_layer_weight(cls, layer, weight):  # Only compute one time when model load
+    def bake_inner(cls, parameter):  # Only compute one time when model load
         # Copyright Forge 2024
 
-        blocks = weight.data
-        K_SCALE_SIZE = 12
+        blocks = parameter.data
         n_blocks = blocks.shape[0]
-        d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE])
-        d = d.view(torch.float16).to(cls.computation_dtype)
-        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
+        d, dmin, scales, qs = quick_split(blocks, [2, 2, cls.K_SCALE_SIZE])
+        d = d.view(torch.float16).to(parameter.computation_dtype)
+        dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
         sc, m = Q4_K.get_scale_min_pytorch(scales)
         d = (d * sc).reshape((n_blocks, -1, 1))
-        dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype)
+        dm = (dmin * m).reshape((n_blocks, -1, 1)).to(parameter.computation_dtype)
 
         qs = qs.reshape((n_blocks, -1, 1, 32))
         qs = change_4bits_order(qs)
 
-        weight.data = qs
-        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
-        layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False)
+        d = d.view(torch.uint8).reshape((n_blocks, -1))
+        dm = dm.view(torch.uint8).reshape((n_blocks, -1))
+        qs = qs.view(torch.uint8)
+
+        parameter.data = torch.cat([d, dm, qs], dim=-1).contiguous()
         return
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
         # Compute in each diffusion iteration
 
         n_blocks = blocks.shape[0]
 
-        d, dm, qs = parent.quant_state_0, parent.quant_state_1, blocks
-
-        if d.device != qs.device:
-            d = d.to(device=qs.device)
-
-        if dm.device != qs.device:
-            dm = dm.to(device=qs.device)
-
-        qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32))
+        d, dm, qs = quick_split(blocks, [16, 16])
+        d = d.view(parameter.computation_dtype).view((n_blocks, -1, 1))
+        dm = dm.view(parameter.computation_dtype).view((n_blocks, -1, 1))
+        qs = quick_unpack_4bits_u(qs).view((n_blocks, -1, 32))
+
         return (d * qs - dm).reshape((n_blocks, QK_K))
 
 
@@ -867,14 +852,14 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
         return (d * q - dm).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         QK_K = 256
         K_SCALE_SIZE = 12
         n_blocks = blocks.shape[0]
         d, dmin, scales, qh, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE, QK_K // 8])
-        d = d.view(torch.float16).to(cls.computation_dtype)
-        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
+        d = d.view(torch.float16).to(parameter.computation_dtype)
+        dmin = dmin.view(torch.float16).to(parameter.computation_dtype)
         sc, m = Q4_K.get_scale_min_pytorch(scales)
         d = (d * sc).reshape((n_blocks, -1, 1))
         dm = (dmin * m).reshape((n_blocks, -1, 1))
@@ -909,12 +894,12 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
         return (d * q).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor:
         # Written by ChatGPT
         n_blocks = blocks.shape[0]
         ql, qh, scales, d, = quick_split(blocks, [QK_K // 2, QK_K // 4, QK_K // 16])
-        scales = scales.view(torch.int8).to(cls.computation_dtype)
-        d = d.view(torch.float16).to(cls.computation_dtype)
+        scales = scales.view(torch.int8).to(parameter.computation_dtype)
+        d = d.view(torch.float16).to(parameter.computation_dtype)
         d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
         ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
         ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
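For orientation, a minimal end-to-end sketch of the new bake-then-view flow for Q8_0, mirroring the hunks above. The SimpleNamespace is a hypothetical stand-in for the real GGUF parameter wrapper (defined outside this file); its data, shape, baked, and computation_dtype attributes are assumed from their use in the diff:

from types import SimpleNamespace
import torch

TYPE_SIZE = 34  # Q8_0: a 2-byte fp16 scale plus 32 int8 quants per block

# Hypothetical stand-in for the real parameter wrapper used by the diff.
param = SimpleNamespace(
    data=torch.randint(0, 256, (64, TYPE_SIZE), dtype=torch.uint8),
    shape=torch.Size([64, 32]),  # illustrative logical weight shape
    computation_dtype=torch.bfloat16,
    baked=False,
)

def bake_q8_0(p):
    # One-time work at load: convert the fp16 scales to the computation
    # dtype and repack scales + quants into one contiguous tensor.
    d, x = torch.split(p.data, [2, TYPE_SIZE - 2], dim=-1)
    d = d.view(torch.float16).to(p.computation_dtype).view(torch.int8)
    p.data = torch.cat([d, x.view(torch.int8)], dim=-1).contiguous()
    p.baked = True

def dequantize_q8_0(p):
    # Per-step work at inference: only zero-copy views, no dtype conversion
    # and no cross-device .to() calls.
    d, x = torch.split(p.data, [2, p.data.shape[-1] - 2], dim=-1)
    d = d.view(p.computation_dtype)
    return (x * d).view(p.shape)

bake_q8_0(param)
weight = dequantize_q8_0(param)  # (64, 32) tensor in computation_dtype

Compared with the old quant_state_0 path, the dtype conversion and device transfer happen once in bake_q8_0 instead of on every dequantize call.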