diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py index 98cd0ff5..c0d144d5 100644 --- a/packages_3rdparty/gguf/quants.py +++ b/packages_3rdparty/gguf/quants.py @@ -8,6 +8,7 @@ from numpy.typing import DTypeLike from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K from .lazy import LazyNumpyTensor +from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpack_4bits_u import numpy as np @@ -306,22 +307,20 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0): blocks = weight.data d, x = quick_split(blocks, [2]) d = d.view(torch.float16).to(cls.computation_dtype) + x = change_4bits_order(x) weight.data = x layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) return @classmethod def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - n_blocks = blocks.shape[0] - d, qs = parent.quant_state_0, blocks if d.device != qs.device: d = d.to(device=qs.device) - qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) - qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 - return (d * qs) + qs = quick_unpack_4bits(qs) + return d * qs @classmethod def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: @@ -389,6 +388,8 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): d = d.view(torch.float16).to(cls.computation_dtype) m = m.view(torch.float16).to(cls.computation_dtype) + qs = change_4bits_order(qs) + weight.data = qs layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False) @@ -396,8 +397,6 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): @classmethod def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - n_blocks = blocks.shape[0] - d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks if d.device != qs.device: @@ -406,9 +405,7 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): if m.device != qs.device: m = m.to(device=qs.device) - qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) - qs = (qs & 0x0F).reshape(n_blocks, -1) - + qs = quick_unpack_4bits_u(qs) return (d * qs) + m @@ -817,6 +814,9 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): d = (d * sc).reshape((n_blocks, -1, 1)) dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype) + qs = qs.reshape((n_blocks, -1, 1, 32)) + qs = change_4bits_order(qs) + weight.data = qs layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False) @@ -836,8 +836,7 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): if dm.device != qs.device: dm = dm.to(device=qs.device) - qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) - qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) + qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32)) return (d * qs - dm).reshape((n_blocks, QK_K)) diff --git a/packages_3rdparty/gguf/quick_4bits_ops.py b/packages_3rdparty/gguf/quick_4bits_ops.py new file mode 100644 index 00000000..97404bbc --- /dev/null +++ b/packages_3rdparty/gguf/quick_4bits_ops.py @@ -0,0 +1,61 @@ +# By Forge + + +import torch + + +def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x): + x = x.view(torch.uint8).view(x.size(0), -1) + unpacked = torch.stack([x & 15, x >> 4], dim=-1) + reshaped = unpacked.view(x.size(0), -1) + reshaped = reshaped.to(torch.int8) - 8 + return reshaped.view(torch.int32) + + +def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x): + x = x.view(torch.uint8).view(x.size(0), -1) + unpacked = torch.stack([x & 15, x >> 4], dim=-1) + reshaped = unpacked.view(x.size(0), -1) + return reshaped.view(torch.int32) + + +native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0] +native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0] + + +def quick_unpack_4bits(x): + global native_4bits_lookup_table + + s0 = x.size(0) + x = x.view(torch.uint16) + + if native_4bits_lookup_table.device != x.device: + native_4bits_lookup_table = native_4bits_lookup_table.to(device=x.device) + + y = torch.index_select(input=native_4bits_lookup_table, dim=0, index=x.to(dtype=torch.int32).flatten()) + y = y.view(torch.int8) + y = y.view(s0, -1) + + return y + + +def quick_unpack_4bits_u(x): + global native_4bits_lookup_table_u + + s0 = x.size(0) + x = x.view(torch.uint16) + + if native_4bits_lookup_table_u.device != x.device: + native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=x.device) + + y = torch.index_select(input=native_4bits_lookup_table_u, dim=0, index=x.to(dtype=torch.int32).flatten()) + y = y.view(torch.uint8) + y = y.view(s0, -1) + + return y + + +def change_4bits_order(x): + y = torch.stack([x & 15, x >> 4], dim=-2).view(x.size(0), -1) + z = y[:, ::2] | (y[:, 1::2] << 4) + return z