From 243952f3642162b27ab2f2561048020eb162a5d8 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Thu, 15 Aug 2024 17:07:41 -0700
Subject: [PATCH] wip qx_1 loras

---
 backend/operations_gguf.py       |  2 +
 packages_3rdparty/gguf/quants.py | 86 ++++++++++++++++++++++++++++++++
 2 files changed, 88 insertions(+)

diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py
index abd52a53..fdd565ec 100644
--- a/backend/operations_gguf.py
+++ b/backend/operations_gguf.py
@@ -4,7 +4,9 @@ import torch
 
 quants_mapping = {
     gguf.GGMLQuantizationType.Q4_0: gguf.Q4_0,
+    gguf.GGMLQuantizationType.Q4_1: gguf.Q4_1,
     gguf.GGMLQuantizationType.Q5_0: gguf.Q5_0,
+    gguf.GGMLQuantizationType.Q5_1: gguf.Q5_1,
     gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
 }
 
diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py
index bc434438..e861baf9 100644
--- a/packages_3rdparty/gguf/quants.py
+++ b/packages_3rdparty/gguf/quants.py
@@ -354,6 +354,43 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
 
         return (d * qs) + m
 
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        n_blocks = blocks.shape[0]
+
+        d = blocks[:, :2].view(torch.float16)
+        m = blocks[:, 2:4].view(torch.float16)
+        qs = blocks[:, 4:]
+
+        qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
+        qs = (qs & 0x0F).reshape(n_blocks, -1)
+
+        return (d * qs) + m
+
+    @classmethod
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        # WIP
+
+        raise NotImplementedError('Q4_1 Lora is under construction!')
+
+        n_blocks = blocks.shape[0]
+
+        max_vals = blocks.max(dim=-1, keepdim=True).values
+        min_vals = blocks.min(dim=-1, keepdim=True).values
+
+        d = (max_vals - min_vals) / 15
+        id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1 / d)
+
+        qs = torch.trunc((blocks - min_vals) * id + 0.5).to(torch.uint8).clip(0, 15)
+
+        qs = qs.view(n_blocks, 2, block_size // 2)
+        qs = qs[:, 0, :] | (qs[:, 1, :] << 4)
+
+        d = d.to(torch.float16).view(n_blocks, -1)
+        m = min_vals.to(torch.float16).view(n_blocks, -1)
+
+        return torch.cat([d, m, qs], dim=-1)
+
 
 class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
     @classmethod
@@ -503,6 +540,55 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
 
         return (d * qs) + m
 
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        def to_uint32(x):
+            # pytorch uint32 by City96 - Apache-2.0
+            x = x.view(torch.uint8).to(torch.int32)
+            return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
+
+        n_blocks = blocks.shape[0]
+
+        d = blocks[:, :2].view(torch.float16)
+        m = blocks[:, 2:4].view(torch.float16)
+        qh = blocks[:, 4:8]
+        qs = blocks[:, 8:]
+
+        qh = to_uint32(qh)
+
+        qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
+        ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
+        qh = (qh & 1).to(torch.uint8)
+        ql = (ql & 0x0F).reshape((n_blocks, -1))
+
+        qs = (ql | (qh << 4))
+        return (d * qs) + m
+
+    @classmethod
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        # WIP
+
+        raise NotImplementedError('Q5_1 Lora is under construction!')
+
+        n_blocks = blocks.shape[0]
+
+        max_val = blocks.max(dim=-1, keepdim=True)[0]
+        min_val = blocks.min(dim=-1, keepdim=True)[0]
+
+        d = (max_val - min_val) / 31
+        id = torch.where(d == 0, torch.zeros_like(d), 1.0 / d)
+        q = torch.trunc((blocks - min_val) * id + 0.5).clip(0, 31).to(torch.uint8)
+
+        qs = q.view(n_blocks, 2, block_size // 2)
+        qs = (qs[..., 0, :] & 0x0F) | (qs[..., 1, :] << 4)
+
+        qh = torch.bitwise_right_shift(q.view(n_blocks, 1, 32), torch.arange(4, dtype=torch.uint8, device=blocks.device) * 8).byte()
+
+        d = d.to(torch.float16).view(-1, 2)
+        min_val = min_val.to(torch.float16).view(-1, 2)
+
+        return torch.cat([d.byte(), min_val.byte(), qh, qs], dim=-1)
+
 
 class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
     @classmethod
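
Note (editorial, not part of the patch): the quantize_blocks_pytorch methods
above are WIP stubs guarded by raise NotImplementedError, so only the new
dequantize paths are reachable. The Q4_1 dequantizer follows the standard
GGML Q4_1 block layout -- per block of 32 weights: 2 bytes of fp16 scale d,
2 bytes of fp16 offset m, then 16 bytes holding 32 packed 4-bit quants,
reconstructed as d * q + m. Below is a minimal standalone sketch of that
unpacking; the helper name dequantize_q4_1 and the hand-built test block are
illustrative assumptions, not code from this repository.

    import torch

    def dequantize_q4_1(blocks: torch.Tensor, block_size: int = 32) -> torch.Tensor:
        # mirrors Q4_1.dequantize_blocks_pytorch from the patch (CPU-only sketch)
        n_blocks = blocks.shape[0]
        d = blocks[:, :2].view(torch.float16)   # per-block fp16 scale
        m = blocks[:, 2:4].view(torch.float16)  # per-block fp16 offset (block minimum)
        qs = blocks[:, 4:]                      # 16 bytes = 32 packed 4-bit values
        shifts = torch.tensor([0, 4], dtype=torch.uint8).reshape(1, 1, 2, 1)
        qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> shifts
        qs = (qs & 0x0F).reshape(n_blocks, -1)  # low nibbles first, then high nibbles
        return (d * qs) + m

    # one 20-byte block: d = 1.0, m = 0.0, byte i carrying value i in both nibbles
    d_bytes = torch.tensor([1.0], dtype=torch.float16).view(torch.uint8)
    m_bytes = torch.tensor([0.0], dtype=torch.float16).view(torch.uint8)
    i = torch.arange(16, dtype=torch.uint8)
    block = torch.cat([d_bytes, m_bytes, i | (i << 4)]).unsqueeze(0)
    print(dequantize_q4_1(block))  # [[0., 1., ..., 15., 0., 1., ..., 15.]]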
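
Q5_1 differs only in carrying a fifth bit per weight: each 24-byte block is
2 bytes fp16 d, 2 bytes fp16 m, 4 bytes qh (a little-endian uint32 whose bit
i is bit 4 of quant i), and 16 bytes of low nibbles. The to_uint32 helper in
the patch exists because PyTorch lacks uint32 arithmetic, so the bitfield is
assembled in int32. A sketch of just that bit plumbing; the helper name
unpack_q5_1_quants and the example inputs are illustrative assumptions.

    import torch

    def unpack_q5_1_quants(qh_bytes: torch.Tensor, qs_bytes: torch.Tensor) -> torch.Tensor:
        n_blocks = qh_bytes.shape[0]
        # assemble the little-endian uint32 bitfield in int32, as the patch does
        qh = (qh_bytes[:, 0].to(torch.int32)
              | (qh_bytes[:, 1].to(torch.int32) << 8)
              | (qh_bytes[:, 2].to(torch.int32) << 16)
              | (qh_bytes[:, 3].to(torch.int32) << 24)).unsqueeze(1)
        qh = ((qh >> torch.arange(32, dtype=torch.int32)) & 1).to(torch.uint8)  # bit i -> quant i
        shifts = torch.tensor([0, 4], dtype=torch.uint8).reshape(1, 1, 2, 1)
        ql = qs_bytes.reshape((n_blocks, -1, 1, 16)) >> shifts
        ql = (ql & 0x0F).reshape((n_blocks, -1))  # 32 low nibbles per block
        return ql | (qh << 4)                     # 5-bit values in 0..31

    qh = torch.tensor([[0xFF, 0x00, 0x00, 0x00]], dtype=torch.uint8)  # bits 0..7 set
    qs = torch.zeros((1, 16), dtype=torch.uint8)
    print(unpack_q5_1_quants(qh, qs))  # quants 0..7 come out as 16, the rest as 0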