# stable-diffusion-webui-forge/backend/operations_gguf.py
# (commit d8b83a9501 "gguf preview" by layerdiffusion, 2024-08-15)

import gguf
import torch


def functional_linear_gguf(x, weight, bias=None):
    # Dequantize the GGUF-packed weight (and bias, if present) to the
    # input's dtype on the fly, then defer to the standard linear op.
    target_dtype = x.dtype
    weight = dequantize_tensor(weight, target_dtype)
    bias = dequantize_tensor(bias, target_dtype)
    return torch.nn.functional.linear(x, weight, bias)
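

# A minimal sketch (not part of the original file): functional_linear_gguf and
# dequantize_tensor expect tensor-like objects that carry the raw quantized
# bytes in `.data` plus `gguf_type` and `gguf_real_shape` attributes. The
# wrapper class Forge actually uses elsewhere in the backend may differ; this
# stub only illustrates the assumed interface.
class _GGUFTensorStub:
    def __init__(self, data, gguf_type, gguf_real_shape):
        self.data = data                        # raw quantized bytes (numpy or torch)
        self.gguf_type = gguf_type              # e.g. gguf.GGMLQuantizationType.Q4_0
        self.gguf_real_shape = gguf_real_shape  # logical shape, e.g. torch.Size([out, in])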


def dequantize_tensor(tensor, dtype=torch.float16):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    if tensor is None:
        return None

    data = torch.tensor(tensor.data)
    qtype = tensor.gguf_type
    oshape = tensor.gguf_real_shape

    if qtype == gguf.GGMLQuantizationType.F32:
        return data.to(dtype)
    elif qtype == gguf.GGMLQuantizationType.F16:
        return data.to(dtype)
    elif qtype in dequantize_functions:
        # this is the main pytorch op
        return dequantize(data, qtype, oshape).to(dtype)
    else:
        # this is incredibly slow: fall back to the reference numpy
        # implementation in gguf-py for quant types without a torch kernel
        new = gguf.quants.dequantize(data.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(data.device, dtype=dtype)


def dequantize(data, qtype, oshape):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]

    # flatten to byte rows, then split into fixed-size quantization blocks
    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size)
    return blocks.reshape(oshape)
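

# Worked example (added for illustration): for a Q4_0 tensor of logical shape
# (4096, 4096), GGML_QUANT_SIZES[Q4_0] == (32, 18), so the raw data holds
# 4096 * 4096 / 32 = 524288 blocks of 18 bytes each: a 2-byte fp16 scale
# followed by 16 bytes packing 32 4-bit quants.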


def to_uint32(x):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    # no uint32 in torch :( -- assemble four little-endian uint8 bytes
    # into a single int32 per row instead
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
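

# Block layouts handled below (all three formats use 32-element blocks):
#   Q8_0: 2-byte fp16 scale d + 32 int8 quants; value = d * q
#   Q5_0: 2-byte d + 4 bytes of high bits + 16 bytes of low nibbles;
#         q = (low | high << 4) - 16, i.e. 5-bit quants in [-16, 15]
#   Q4_0: 2-byte d + 16 bytes of nibbles; q = nibble - 8, in [-8, 7]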


def dequantize_blocks_Q8_0(blocks, block_size, type_size):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    d = blocks[:, :2].view(torch.float16)                 # per-block scale
    x = blocks[:, 2:].view(torch.int8).to(torch.float16)  # 32 int8 quants
    return (x * d)


def dequantize_blocks_Q5_0(blocks, block_size, type_size):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    n_blocks = blocks.shape[0]

    d = blocks[:, :2]    # per-block fp16 scale
    qh = blocks[:, 2:6]  # 32 packed high bits, one per quant
    qs = blocks[:, 6:]   # 16 bytes of low nibbles, two quants per byte

    d = d.view(torch.float16).to(torch.float32)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    # recombine low nibble and high bit, then shift into the signed range [-16, 15]
    qs = (ql | (qh << 4)).to(torch.int8) - 16

    return (d * qs)


def dequantize_blocks_Q4_0(blocks, block_size, type_size):
    # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
    n_blocks = blocks.shape[0]

    d = blocks[:, :2].view(torch.float16)  # per-block fp16 scale
    qs = blocks[:, 2:]                     # packed 4-bit quants, two per byte

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8  # shift into [-8, 7]

    return (d * qs)


dequantize_functions = {
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
}