Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-01-26 19:09:45 +00:00)
wip qx_1 loras
@@ -4,7 +4,9 @@ import torch
quants_mapping = {
    gguf.GGMLQuantizationType.Q4_0: gguf.Q4_0,
    gguf.GGMLQuantizationType.Q4_1: gguf.Q4_1,
    gguf.GGMLQuantizationType.Q5_0: gguf.Q5_0,
    gguf.GGMLQuantizationType.Q5_1: gguf.Q5_1,
    gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
}
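This mapping ties each GGUF quantization type enum to the vendored quant class that implements its kernels; the point of the commit is that Q4_1 and Q5_1 now sit alongside the already-supported Q4_0, Q5_0, and Q8_0. A minimal dispatch sketch (the dequantize_gguf_tensor name and raw_blocks argument are illustrative, and it assumes the vendored classes expose block_size/type_size attributes, as gguf-py's quant classes do):

import gguf
import torch

def dequantize_gguf_tensor(raw_blocks: torch.Tensor, qtype) -> torch.Tensor:
    # Pick the vendored quant class for this GGUF tensor type and run its
    # PyTorch dequantizer; unsupported types fail loudly.
    quant_cls = quants_mapping.get(qtype)
    if quant_cls is None:
        raise NotImplementedError(f'no PyTorch kernels for {qtype}')
    return quant_cls.dequantize_blocks_pytorch(
        raw_blocks, quant_cls.block_size, quant_cls.type_size)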
packages_3rdparty/gguf/quants.py (vendored): 86 additions
@@ -354,6 +354,43 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
        return (d * qs) + m

    @classmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        n_blocks = blocks.shape[0]

        # Q4_1 block layout: fp16 scale d, fp16 minimum m, then
        # block_size // 2 bytes of packed 4-bit quants.
        d = blocks[:, :2].view(torch.float16)
        m = blocks[:, 2:4].view(torch.float16)
        qs = blocks[:, 4:]

        # Unpack two 4-bit quants per byte: low nibbles first, high nibbles second.
        qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
        qs = (qs & 0x0F).reshape(n_blocks, -1)

        return (d * qs) + m

    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        # WIP: everything below the raise is unreachable scaffolding.
        raise NotImplementedError('Q4_1 Lora is under construction!')

        n_blocks = blocks.shape[0]

        max_vals = blocks.max(dim=-1, keepdim=True).values
        min_vals = blocks.min(dim=-1, keepdim=True).values

        # Affine quantization: x ~ d * q + m with q in [0, 15].
        d = (max_vals - min_vals) / 15
        id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1 / d)

        qs = torch.trunc((blocks - min_vals) * id + 0.5).to(torch.uint8).clip(0, 15)

        # Pack pairs of quants: first half to low nibbles, second half to high nibbles.
        qs = qs.view(n_blocks, 2, block_size // 2)
        qs = qs[:, 0, :] | (qs[:, 1, :] << 4)

        d = d.to(torch.float16).view(n_blocks, -1)
        m = min_vals.to(torch.float16).view(n_blocks, -1)

        # Note (WIP): d and m are still fp16 here; a byte-exact Q4_1 block would
        # reinterpret them as raw bytes (e.g. .view(torch.uint8)) before
        # concatenating with the uint8 qs.
        return torch.cat([d, m, qs], dim=-1)


class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
    @classmethod
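To make the nibble scheme in dequantize_blocks_pytorch concrete: a Q4_1 block of 32 weights takes 20 bytes (fp16 d, fp16 m, 16 packed bytes), with the 16 low nibbles holding quants 0..15 and the 16 high nibbles holding quants 16..31. A tiny self-contained check with made-up d and m values (not from the patch):

import torch

packed = torch.arange(16, dtype=torch.uint8)   # one block's 16 packed bytes
lo = packed & 0x0F                             # quants 0..15 (low nibbles)
hi = (packed >> 4) & 0x0F                      # quants 16..31 (high nibbles)
q = torch.cat([lo, hi])                        # same order the reshape above produces
d, m = 0.5, -1.0                               # illustrative scale and minimum
x = d * q.to(torch.float16) + m                # dequantized block of 32 weights
assert x.shape == (32,) and x.min() >= m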
@@ -503,6 +540,55 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
        return (d * qs) + m

    @classmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        def to_uint32(x):
            # pytorch uint32 by City96 - Apache-2.0: assemble little-endian
            # uint32 values from four uint8 columns (torch has no uint32 math).
            x = x.view(torch.uint8).to(torch.int32)
            return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)

        n_blocks = blocks.shape[0]

        # Q5_1 block layout: fp16 scale d, fp16 minimum m, 4-byte qh holding the
        # fifth bit of each quant, then block_size // 2 bytes of low nibbles.
        d = blocks[:, :2].view(torch.float16)
        m = blocks[:, 2:4].view(torch.float16)
        qh = blocks[:, 4:8]
        qs = blocks[:, 8:]

        qh = to_uint32(qh)

        # Bit i of qh is the high (fifth) bit of quant i.
        qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
        ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
        qh = (qh & 1).to(torch.uint8)
        ql = (ql & 0x0F).reshape((n_blocks, -1))

        # Recombine into 5-bit quants in [0, 31].
        qs = (ql | (qh << 4))
        return (d * qs) + m

    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        # WIP: everything below the raise is unreachable scaffolding.
        raise NotImplementedError('Q5_1 Lora is under construction!')

        n_blocks = blocks.shape[0]

        max_val = blocks.max(dim=-1, keepdim=True)[0]
        min_val = blocks.min(dim=-1, keepdim=True)[0]

        # Affine quantization: x ~ d * q + min with q in [0, 31].
        d = (max_val - min_val) / 31
        id = torch.where(d == 0, torch.zeros_like(d), 1.0 / d)
        q = torch.trunc((blocks - min_val) * id + 0.5).clip(0, 31).to(torch.uint8)

        # Pack the low 4 bits: first half to low nibbles, second half to high nibbles.
        qs = q.view(n_blocks, 2, block_size // 2)
        qs = (qs[..., 0, :] & 0x0F) | (qs[..., 1, :] << 4)

        # Note (WIP): this shift does not yet extract the fifth bits; qh should
        # collect bit 4 of each of the 32 quants into one little-endian uint32
        # (see the sketch after this hunk).
        qh = torch.bitwise_right_shift(q.view(n_blocks, 1, 32), torch.arange(4, dtype=torch.uint8, device=blocks.device) * 8).byte()

        d = d.to(torch.float16).view(-1, 2)
        min_val = min_val.to(torch.float16).view(-1, 2)

        # Note (WIP): .byte() value-casts the fp16 fields; the on-disk layout
        # needs a byte reinterpret (e.g. .view(torch.uint8)) instead.
        return torch.cat([d.byte(), min_val.byte(), qh, qs], dim=-1)


class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
    @classmethod
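quantize_blocks_pytorch for Q5_1 is still stubbed out behind the NotImplementedError, so here is one shape-consistent way the missing qh packing could look: a sketch only, assuming q is the (n_blocks, 32) uint8 tensor of 5-bit quants built above, not code from this commit:

import torch

def pack_qh(q: torch.Tensor) -> torch.Tensor:
    # q: (n_blocks, 32) uint8 quants in [0, 31]. Bit i of the 32-bit qh word
    # is the fifth bit of quant i; GGML stores qh as 4 little-endian bytes.
    fifth = ((q >> 4) & 1).to(torch.int64)                    # (n, 32)
    shifts = torch.arange(32, dtype=torch.int64, device=q.device)
    qh = (fifth << shifts).sum(dim=-1, keepdim=True)          # (n, 1), fits in 32 bits
    byte_shifts = torch.arange(4, dtype=torch.int64, device=q.device) * 8
    return ((qh >> byte_shifts) & 0xFF).to(torch.uint8)       # (n, 4) little-endian bytes

Feeding these four bytes back through to_uint32 and the per-bit shift in dequantize_blocks_pytorch recovers the fifth bits, which gives a cheap round-trip sanity check.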