From e60bb1c96fbcc257a4dbfc8d212df24a363cf379 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:02:54 -0700 Subject: [PATCH] Make Q4_K_S as fast as Q4_0 by baking the layer when model load --- packages_3rdparty/gguf/quants.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py index dcea8f5a..98cd0ff5 100644 --- a/packages_3rdparty/gguf/quants.py +++ b/packages_3rdparty/gguf/quants.py @@ -804,9 +804,10 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): return (d * qs - dm).reshape((n_blocks, QK_K)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) - QK_K = 256 + def bake_layer_weight(cls, layer, weight): # Only compute one time when model load + # Copyright Forge 2024 + + blocks = weight.data K_SCALE_SIZE = 12 n_blocks = blocks.shape[0] d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE]) @@ -814,7 +815,27 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): dmin = dmin.view(torch.float16).to(cls.computation_dtype) sc, m = Q4_K.get_scale_min_pytorch(scales) d = (d * sc).reshape((n_blocks, -1, 1)) - dm = (dmin * m).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype) + + weight.data = qs + layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) + layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False) + return + + @classmethod + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + # Compute in each diffusion iteration + + n_blocks = blocks.shape[0] + + d, dm, qs = parent.quant_state_0, parent.quant_state_1, blocks + + if d.device != qs.device: + d = d.to(device=qs.device) + + if dm.device != qs.device: + dm = dm.to(device=qs.device) + qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) return (d * qs - dm).reshape((n_blocks, QK_K))