From 82dfc2b15be168c43e4d65585343b3911f561297 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Sun, 25 Aug 2024 16:49:23 -0700 Subject: [PATCH] Significantly speed up Q4_0, Q4_1, Q4_K by precomputing all possible 4bit dequant into a lookup table and use pytorch indexing to get dequant, rather than really computing the bit operations. This should give very similar performance to native CUDA kernels, while being LoRA friendly and more flexiable --- packages_3rdparty/gguf/quants.py | 23 ++++----- packages_3rdparty/gguf/quick_4bits_ops.py | 61 +++++++++++++++++++++++ 2 files changed, 72 insertions(+), 12 deletions(-) create mode 100644 packages_3rdparty/gguf/quick_4bits_ops.py diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py index 98cd0ff5..c0d144d5 100644 --- a/packages_3rdparty/gguf/quants.py +++ b/packages_3rdparty/gguf/quants.py @@ -8,6 +8,7 @@ from numpy.typing import DTypeLike from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K from .lazy import LazyNumpyTensor +from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpack_4bits_u import numpy as np @@ -306,22 +307,20 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0): blocks = weight.data d, x = quick_split(blocks, [2]) d = d.view(torch.float16).to(cls.computation_dtype) + x = change_4bits_order(x) weight.data = x layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) return @classmethod def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - n_blocks = blocks.shape[0] - d, qs = parent.quant_state_0, blocks if d.device != qs.device: d = d.to(device=qs.device) - qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) - qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 - return (d * qs) + qs = quick_unpack_4bits(qs) + return d * qs @classmethod def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: @@ -389,6 +388,8 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): d = d.view(torch.float16).to(cls.computation_dtype) m = m.view(torch.float16).to(cls.computation_dtype) + qs = change_4bits_order(qs) + weight.data = qs layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False) @@ -396,8 +397,6 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): @classmethod def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - n_blocks = blocks.shape[0] - d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks if d.device != qs.device: @@ -406,9 +405,7 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): if m.device != qs.device: m = m.to(device=qs.device) - qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) - qs = (qs & 0x0F).reshape(n_blocks, -1) - + qs = quick_unpack_4bits_u(qs) return (d * qs) + m @@ -817,6 +814,9 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): d = (d * sc).reshape((n_blocks, -1, 1)) dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype) + qs = qs.reshape((n_blocks, -1, 1, 32)) + qs = change_4bits_order(qs) + weight.data = qs layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False) @@ -836,8 +836,7 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): if dm.device != qs.device: dm = dm.to(device=qs.device) - qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) - qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) + qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32)) return (d * qs - dm).reshape((n_blocks, QK_K)) diff --git a/packages_3rdparty/gguf/quick_4bits_ops.py b/packages_3rdparty/gguf/quick_4bits_ops.py new file mode 100644 index 00000000..97404bbc --- /dev/null +++ b/packages_3rdparty/gguf/quick_4bits_ops.py @@ -0,0 +1,61 @@ +# By Forge + + +import torch + + +def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x): + x = x.view(torch.uint8).view(x.size(0), -1) + unpacked = torch.stack([x & 15, x >> 4], dim=-1) + reshaped = unpacked.view(x.size(0), -1) + reshaped = reshaped.to(torch.int8) - 8 + return reshaped.view(torch.int32) + + +def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x): + x = x.view(torch.uint8).view(x.size(0), -1) + unpacked = torch.stack([x & 15, x >> 4], dim=-1) + reshaped = unpacked.view(x.size(0), -1) + return reshaped.view(torch.int32) + + +native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0] +native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0] + + +def quick_unpack_4bits(x): + global native_4bits_lookup_table + + s0 = x.size(0) + x = x.view(torch.uint16) + + if native_4bits_lookup_table.device != x.device: + native_4bits_lookup_table = native_4bits_lookup_table.to(device=x.device) + + y = torch.index_select(input=native_4bits_lookup_table, dim=0, index=x.to(dtype=torch.int32).flatten()) + y = y.view(torch.int8) + y = y.view(s0, -1) + + return y + + +def quick_unpack_4bits_u(x): + global native_4bits_lookup_table_u + + s0 = x.size(0) + x = x.view(torch.uint16) + + if native_4bits_lookup_table_u.device != x.device: + native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=x.device) + + y = torch.index_select(input=native_4bits_lookup_table_u, dim=0, index=x.to(dtype=torch.int32).flatten()) + y = y.view(torch.uint8) + y = y.view(s0, -1) + + return y + + +def change_4bits_order(x): + y = torch.stack([x & 15, x >> 4], dim=-2).view(x.size(0), -1) + z = y[:, ::2] | (y[:, 1::2] << 4) + return z