wip qx_1 loras

layerdiffusion
2024-08-15 17:07:41 -07:00
parent 5fb67f49e8
commit 243952f364
2 changed files with 88 additions and 0 deletions


@@ -4,7 +4,9 @@ import torch
quants_mapping = {
    gguf.GGMLQuantizationType.Q4_0: gguf.Q4_0,
    gguf.GGMLQuantizationType.Q4_1: gguf.Q4_1,
    gguf.GGMLQuantizationType.Q5_0: gguf.Q5_0,
    gguf.GGMLQuantizationType.Q5_1: gguf.Q5_1,
    gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
}
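
Not part of this diff: a minimal sketch of how quants_mapping might be used to dispatch dequantization by GGML quantization type. The dequantize_gguf_tensor helper is hypothetical, and it assumes the quant classes expose block_size / type_size attributes as in upstream gguf-py.

import torch
import gguf

def dequantize_gguf_tensor(raw: torch.Tensor, qtype: gguf.GGMLQuantizationType) -> torch.Tensor:
    # Hypothetical helper: look up the quant class for this GGML type and run
    # its PyTorch dequantizer over the raw uint8 block data.
    quant_cls = quants_mapping[qtype]
    blocks = raw.reshape(-1, quant_cls.type_size)
    out = quant_cls.dequantize_blocks_pytorch(blocks, quant_cls.block_size, quant_cls.type_size)
    return out.reshape(-1)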


@@ -354,6 +354,43 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
        return (d * qs) + m

    @classmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        n_blocks = blocks.shape[0]
        d = blocks[:, :2].view(torch.float16)
        m = blocks[:, 2:4].view(torch.float16)
        qs = blocks[:, 4:]
        qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
        qs = (qs & 0x0F).reshape(n_blocks, -1)
        return (d * qs) + m
    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        # WIP
        raise NotImplementedError('Q4_1 Lora is under construction!')
        n_blocks = blocks.shape[0]
        max_vals = blocks.max(dim=-1, keepdim=True).values
        min_vals = blocks.min(dim=-1, keepdim=True).values
        d = (max_vals - min_vals) / 15
        id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1 / d)
        qs = torch.trunc((blocks - min_vals) * id + 0.5).to(torch.uint8).clip(0, 15)
        # Pack pairs of 4-bit values: first half of the block in the low nibble,
        # second half in the high nibble.
        qs = qs.view(n_blocks, 2, block_size // 2)
        qs = qs[:, 0, :] | (qs[:, 1, :] << 4)
        # Reinterpret the f16 scale and minimum as raw bytes so they can be
        # concatenated with the uint8 payload.
        d = d.to(torch.float16).view(torch.uint8)
        m = min_vals.to(torch.float16).view(torch.uint8)
        return torch.cat([d, m, qs], dim=-1)
class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
    @classmethod
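
Not part of this diff: a rough sanity check for the Q4_1 PyTorch dequantizer against the existing NumPy path, assuming the upstream quantize_blocks / dequantize_blocks classmethods operate on per-block arrays of shape (n_blocks, block_size) and (n_blocks, type_size) respectively.

import numpy as np
import torch

def check_q4_1_dequant(n_blocks: int = 8) -> None:
    block_size, type_size = Q4_1.block_size, Q4_1.type_size
    data = np.random.randn(n_blocks, block_size).astype(np.float32)
    packed = Q4_1.quantize_blocks(data)              # NumPy reference quantizer
    ref = Q4_1.dequantize_blocks(packed)             # NumPy reference dequantizer
    out = Q4_1.dequantize_blocks_pytorch(torch.from_numpy(packed), block_size, type_size)
    # The PyTorch path keeps the scale/min in float16, so allow a loose tolerance.
    assert torch.allclose(out.float(), torch.from_numpy(ref).float(), atol=1e-2)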
@@ -503,6 +540,55 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
        return (d * qs) + m

    @classmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        def to_uint32(x):
            # pytorch uint32 by City96 - Apache-2.0
            x = x.view(torch.uint8).to(torch.int32)
            return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)

        n_blocks = blocks.shape[0]
        d = blocks[:, :2].view(torch.float16)
        m = blocks[:, 2:4].view(torch.float16)
        qh = blocks[:, 4:8]
        qs = blocks[:, 8:]
        qh = to_uint32(qh)
        qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
        ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
        qh = (qh & 1).to(torch.uint8)
        ql = (ql & 0x0F).reshape((n_blocks, -1))
        qs = (ql | (qh << 4))
        return (d * qs) + m
    @classmethod
    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
        # WIP
        raise NotImplementedError('Q5_1 Lora is under construction!')
        n_blocks = blocks.shape[0]
        max_val = blocks.max(dim=-1, keepdim=True)[0]
        min_val = blocks.min(dim=-1, keepdim=True)[0]
        d = (max_val - min_val) / 31
        id = torch.where(d == 0, torch.zeros_like(d), 1.0 / d)
        q = torch.trunc((blocks - min_val) * id + 0.5).clip(0, 31).to(torch.uint8)
        # Low 4 bits: first half of the block in the low nibble, second half in the high nibble.
        qs = q.view(n_blocks, 2, block_size // 2)
        qs = (qs[..., 0, :] & 0x0F) | (qs[..., 1, :] << 4)
        # High (5th) bits: pack one bit per element into a little-endian 32-bit word,
        # stored as 4 bytes per block.
        qh = ((q >> 4).to(torch.int64) << torch.arange(block_size, device=blocks.device, dtype=torch.int64)).sum(dim=-1, keepdim=True)
        qh = ((qh >> (torch.arange(4, device=blocks.device, dtype=torch.int64) * 8)) & 0xFF).to(torch.uint8)
        # Reinterpret the f16 scale and minimum as raw bytes before concatenation.
        d = d.to(torch.float16).view(torch.uint8)
        min_val = min_val.to(torch.float16).view(torch.uint8)
        return torch.cat([d, min_val, qh, qs], dim=-1)
class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
    @classmethod
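
Not part of this diff: a scalar illustration of the Q5_1 block layout the vectorized code above assumes. Each 24-byte block holds an f16 scale d, an f16 minimum m, 4 bytes of high bits qh, and 16 bytes of packed low nibbles qs, and element j is reconstructed as d * (low_nibble_j | (high_bit_j << 4)) + m. The function name is hypothetical and only meant to make the bit layout explicit.

import torch

def q5_1_element(block: torch.Tensor, j: int) -> torch.Tensor:
    # block: one raw Q5_1 block as 24 uint8 bytes (2 + 2 + 4 + 16), j: element index 0..31.
    d = block[0:2].view(torch.float16)   # scale
    m = block[2:4].view(torch.float16)   # minimum
    qh = block[4:8]                      # 32 high bits, little-endian 32-bit word
    qs = block[8:24]                     # 32 low nibbles, packed two per byte
    high = (qh[j // 8] >> (j % 8)) & 1             # 5th bit of element j
    low = (qs[j % 16] >> (4 * (j // 16))) & 0x0F   # elements 0..15 use low nibbles, 16..31 high nibbles
    q = (high.to(torch.int32) << 4) | low.to(torch.int32)
    return d * q + m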