mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-09 15:09:50 +00:00
Significantly speed up Q4_0, Q4_1, Q4_K
by precomputing all possible 4bit dequant into a lookup table and use pytorch indexing to get dequant, rather than really computing the bit operations. This should give very similar performance to native CUDA kernels, while being LoRA friendly and more flexiable
This commit is contained in:
23
packages_3rdparty/gguf/quants.py
vendored
23
packages_3rdparty/gguf/quants.py
vendored
@@ -8,6 +8,7 @@ from numpy.typing import DTypeLike
|
||||
|
||||
from .constants import GGML_QUANT_SIZES, GGMLQuantizationType, QK_K
|
||||
from .lazy import LazyNumpyTensor
|
||||
from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpack_4bits_u
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -306,22 +307,20 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
|
||||
blocks = weight.data
|
||||
d, x = quick_split(blocks, [2])
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
x = change_4bits_order(x)
|
||||
weight.data = x
|
||||
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, qs = parent.quant_state_0, blocks
|
||||
|
||||
if d.device != qs.device:
|
||||
d = d.to(device=qs.device)
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
|
||||
qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
|
||||
return (d * qs)
|
||||
qs = quick_unpack_4bits(qs)
|
||||
return d * qs
|
||||
|
||||
@classmethod
|
||||
def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
|
||||
@@ -389,6 +388,8 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
|
||||
d = d.view(torch.float16).to(cls.computation_dtype)
|
||||
m = m.view(torch.float16).to(cls.computation_dtype)
|
||||
|
||||
qs = change_4bits_order(qs)
|
||||
|
||||
weight.data = qs
|
||||
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
|
||||
layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False)
|
||||
@@ -396,8 +397,6 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks
|
||||
|
||||
if d.device != qs.device:
|
||||
@@ -406,9 +405,7 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
|
||||
if m.device != qs.device:
|
||||
m = m.to(device=qs.device)
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
|
||||
qs = (qs & 0x0F).reshape(n_blocks, -1)
|
||||
|
||||
qs = quick_unpack_4bits_u(qs)
|
||||
return (d * qs) + m
|
||||
|
||||
|
||||
@@ -817,6 +814,9 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
|
||||
d = (d * sc).reshape((n_blocks, -1, 1))
|
||||
dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype)
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, 32))
|
||||
qs = change_4bits_order(qs)
|
||||
|
||||
weight.data = qs
|
||||
layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
|
||||
layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False)
|
||||
@@ -836,8 +836,7 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
|
||||
if dm.device != qs.device:
|
||||
dm = dm.to(device=qs.device)
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
|
||||
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
|
||||
qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32))
|
||||
return (d * qs - dm).reshape((n_blocks, QK_K))
|
||||
|
||||
|
||||
|
||||
61
packages_3rdparty/gguf/quick_4bits_ops.py
vendored
Normal file
61
packages_3rdparty/gguf/quick_4bits_ops.py
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
# By Forge
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
|
||||
x = x.view(torch.uint8).view(x.size(0), -1)
|
||||
unpacked = torch.stack([x & 15, x >> 4], dim=-1)
|
||||
reshaped = unpacked.view(x.size(0), -1)
|
||||
reshaped = reshaped.to(torch.int8) - 8
|
||||
return reshaped.view(torch.int32)
|
||||
|
||||
|
||||
def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
|
||||
x = x.view(torch.uint8).view(x.size(0), -1)
|
||||
unpacked = torch.stack([x & 15, x >> 4], dim=-1)
|
||||
reshaped = unpacked.view(x.size(0), -1)
|
||||
return reshaped.view(torch.int32)
|
||||
|
||||
|
||||
native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
|
||||
native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
|
||||
|
||||
|
||||
def quick_unpack_4bits(x):
|
||||
global native_4bits_lookup_table
|
||||
|
||||
s0 = x.size(0)
|
||||
x = x.view(torch.uint16)
|
||||
|
||||
if native_4bits_lookup_table.device != x.device:
|
||||
native_4bits_lookup_table = native_4bits_lookup_table.to(device=x.device)
|
||||
|
||||
y = torch.index_select(input=native_4bits_lookup_table, dim=0, index=x.to(dtype=torch.int32).flatten())
|
||||
y = y.view(torch.int8)
|
||||
y = y.view(s0, -1)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def quick_unpack_4bits_u(x):
|
||||
global native_4bits_lookup_table_u
|
||||
|
||||
s0 = x.size(0)
|
||||
x = x.view(torch.uint16)
|
||||
|
||||
if native_4bits_lookup_table_u.device != x.device:
|
||||
native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=x.device)
|
||||
|
||||
y = torch.index_select(input=native_4bits_lookup_table_u, dim=0, index=x.to(dtype=torch.int32).flatten())
|
||||
y = y.view(torch.uint8)
|
||||
y = y.view(s0, -1)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def change_4bits_order(x):
|
||||
y = torch.stack([x & 15, x >> 4], dim=-2).view(x.size(0), -1)
|
||||
z = y[:, ::2] | (y[:, 1::2] << 4)
|
||||
return z
|
||||
Reference in New Issue
Block a user