Significantly speed up Q4_0, Q4_1, Q4_K

by precomputing every possible 4-bit dequantization result into a lookup table and using PyTorch indexing to fetch the dequantized values, rather than actually computing the bit operations.
This should give very similar performance to native CUDA kernels, while being LoRA-friendly and more flexible.
This commit is contained in:
layerdiffusion
2024-08-25 16:49:23 -07:00
parent e60bb1c96f
commit 82dfc2b15b
2 changed files with 72 additions and 12 deletions

View File

@@ -8,6 +8,7 @@ from numpy.typing import DTypeLike
from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K
from .lazy import LazyNumpyTensor
from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpack_4bits_u
import numpy as np
@@ -306,22 +307,20 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
blocks = weight.data
d, x = quick_split(blocks, [2])
d = d.view(torch.float16).to(cls.computation_dtype)
x = change_4bits_order(x)
weight.data = x
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
return
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, qs = parent.quant_state_0, blocks
if d.device != qs.device:
d = d.to(device=qs.device)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
return (d * qs)
qs = quick_unpack_4bits(qs)
return d * qs
@classmethod
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
@@ -389,6 +388,8 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
d = d.view(torch.float16).to(cls.computation_dtype)
m = m.view(torch.float16).to(cls.computation_dtype)
qs = change_4bits_order(qs)
weight.data = qs
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False)
@@ -396,8 +397,6 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
@classmethod
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
n_blocks = blocks.shape[0]
d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks
if d.device != qs.device:
@@ -406,9 +405,7 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
if m.device != qs.device:
m = m.to(device=qs.device)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
qs = (qs & 0x0F).reshape(n_blocks, -1)
qs = quick_unpack_4bits_u(qs)
return (d * qs) + m
@@ -817,6 +814,9 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype)
qs = qs.reshape((n_blocks, -1, 1, 32))
qs = change_4bits_order(qs)
weight.data = qs
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False)
@@ -836,8 +836,7 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
if dm.device != qs.device:
dm = dm.to(device=qs.device)
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32))
return (d * qs - dm).reshape((n_blocks, QK_K))

View File

@@ -0,0 +1,61 @@
# By Forge
import torch
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
x = x.view(torch.uint8).view(x.size(0), -1)
unpacked = torch.stack([x & 15, x >> 4], dim=-1)
reshaped = unpacked.view(x.size(0), -1)
reshaped = reshaped.to(torch.int8) - 8
return reshaped.view(torch.int32)
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
x = x.view(torch.uint8).view(x.size(0), -1)
unpacked = torch.stack([x & 15, x >> 4], dim=-1)
reshaped = unpacked.view(x.size(0), -1)
return reshaped.view(torch.int32)
# Precomputed lookup tables with one entry per possible uint16 value (65536
# rows).  Row i holds — packed into a single int32 — the four 8-bit values
# obtained by unpacking the four 4-bit nibbles of i: the first table applies
# the signed "-8" offset, the "_u" table keeps the raw unsigned nibbles.
# The trailing [:, 0] squeezes the (65536, 1) int32 result to a flat (65536,)
# vector so it can be fed directly to torch.index_select below.
native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
def quick_unpack_4bits(x):
    """Unpack 4-bit quantized data to signed int8 via the precomputed table.

    *x* is reinterpreted as uint16 words; each word indexes
    ``native_4bits_lookup_table``, whose int32 entry carries the word's four
    signed nibbles, which are then viewed back as int8.
    Returns a tensor of shape ``(x.size(0), -1)``.
    """
    global native_4bits_lookup_table
    rows = x.size(0)
    words = x.view(torch.uint16)
    # Migrate the table to the data's device on first use; the move is cached
    # in the module global so later calls pay nothing.
    if native_4bits_lookup_table.device != words.device:
        native_4bits_lookup_table = native_4bits_lookup_table.to(device=words.device)
    indices = words.to(dtype=torch.int32).flatten()
    packed = native_4bits_lookup_table.index_select(0, indices)
    return packed.view(torch.int8).view(rows, -1)
def quick_unpack_4bits_u(x):
    """Unsigned twin of :func:`quick_unpack_4bits`.

    Uses ``native_4bits_lookup_table_u`` (raw 0..15 nibbles, no ``-8``
    offset) and views the packed int32 results back as uint8.
    Returns a tensor of shape ``(x.size(0), -1)``.
    """
    global native_4bits_lookup_table_u
    rows = x.size(0)
    words = x.view(torch.uint16)
    # One-time, cached device migration of the module-level table.
    if native_4bits_lookup_table_u.device != words.device:
        native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=words.device)
    indices = words.to(dtype=torch.int32).flatten()
    packed = native_4bits_lookup_table_u.index_select(0, indices)
    return packed.view(torch.uint8).view(rows, -1)
def change_4bits_order(x):
    """Repack per-row nibble order of 4-bit quantized bytes.

    Each byte of *x* is split into (low, high) nibbles; per row the layout
    becomes [all low nibbles, all high nibbles], and consecutive nibble pairs
    of that sequence are then fused back into bytes (even position -> low
    half, odd position -> high half).  Presumably this makes a plain
    byte-sequential unpack reproduce the GGML nibble order — see the
    quick_unpack helpers' callers.
    """
    low = x & 15
    high = x >> 4
    flat = torch.stack((low, high), dim=-2).view(x.size(0), -1)
    evens = flat[:, 0::2]
    odds = flat[:, 1::2]
    return evens | (odds << 4)