revise GGUF by precomputing some parameters

rather than computing them in each diffusion iteration
layerdiffusion
2024-08-25 14:26:46 -07:00
parent ba01ad3711
commit 13d6f8ed90
5 changed files with 137 additions and 48 deletions
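For context, all of the hunks below implement one idea: the per-block scale factors stored inside GGUF blocks used to be sliced out and cast to a floating dtype on every dequantize call, i.e. once per diffusion step per layer; after this commit they are split out a single time when the model is loaded ("baked" onto the layer), and the per-step path only does the integer unpacking and one multiply. A rough stand-alone sketch of that idea for the Q8_0 layout (a 34-byte block is a 2-byte fp16 scale plus 32 int8 values); bake_once and dequantize_every_step are made-up names, not functions from this commit:

import torch

def bake_once(raw_blocks):
    # raw_blocks: (n_blocks, 34) uint8 Q8_0 blocks = 2-byte fp16 scale + 32 int8 values
    d = raw_blocks[:, :2].view(torch.float16).to(torch.bfloat16)  # converted once, at load time
    qs = raw_blocks[:, 2:]                                        # payload stays quantized
    return d, qs

def dequantize_every_step(d, qs):
    # this is all that runs inside the sampling loop: no header slicing, no fp16 cast of d
    return qs.view(torch.int8).to(d.dtype) * d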


@@ -162,6 +162,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
             model.initial_device = initial_device
             model.offload_device = offload_device
 
+            if storage_dtype in ['gguf']:
+                from backend.operations_gguf import bake_gguf_model
+                model = bake_gguf_model(model)
+
             return model
 
     print(f'Skipped: {component_name} = {lib_name}.{cls_name}')


@@ -405,12 +405,24 @@ class ForgeOperationsGGUF(ForgeOperations):
             self.weight = state_dict[prefix + 'weight']
             if prefix + 'bias' in state_dict:
                 self.bias = state_dict[prefix + 'bias']
+            if self.weight is not None and hasattr(self.weight, 'parent'):
+                self.weight.parent = self
+            if self.bias is not None and hasattr(self.bias, 'parent'):
+                self.bias.parent = self
+            return
 
         def _apply(self, fn, recurse=True):
             if self.weight is not None:
                 self.weight = utils.tensor2parameter(fn(self.weight))
             if self.bias is not None:
                 self.bias = utils.tensor2parameter(fn(self.bias))
+            for i in range(5):
+                quant_state_name = f'quant_state_{i}'
+                quant_state = getattr(self, quant_state_name, None)
+                if quant_state is not None:
+                    quant_state = fn(quant_state)
+                    quant_state = utils.tensor2parameter(quant_state)
+                    setattr(self, quant_state_name, quant_state)
             return self
 
         def forward(self, x):
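Because this GGUF Linear manages its tensors by hand in _apply (PyTorch funnels .to(), .cuda(), .half() and friends through Module._apply), the newly baked quant_state_0 through quant_state_4 tensors have to be pushed through the same fn explicitly, or a later device move or dtype cast would leave them behind. The same pattern in a self-contained toy module (class and attribute names here are invented for illustration; the recurse argument requires a recent PyTorch, as in the diff):

import torch

class LayerWithBakedState(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(4, 4), requires_grad=False)
        self.baked_scale = torch.ones(4)   # plain attribute, not registered as a parameter or buffer

    def _apply(self, fn, recurse=True):
        # fn is whatever .to()/.cuda()/.half() decided to do to every tensor of the module;
        # forwarding it keeps the extra tensor on the same device/dtype as the weights.
        self.baked_scale = fn(self.baked_scale)
        return super()._apply(fn, recurse=recurse)

m = LayerWithBakedState().to(torch.float16)
print(m.baked_scale.dtype)   # torch.float16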


@@ -27,6 +27,7 @@ class ParameterGGUF(torch.nn.Parameter):
         self.gguf_type = tensor.tensor_type
         self.gguf_real_shape = torch.Size(reversed(list(tensor.shape)))
         self.gguf_cls = quants_mapping.get(self.gguf_type, None)
+        self.parent = None
 
     @property
     def shape(self):
@@ -43,6 +44,7 @@ class ParameterGGUF(torch.nn.Parameter):
         new.gguf_type = self.gguf_type
         new.gguf_real_shape = self.gguf_real_shape
         new.gguf_cls = self.gguf_cls
+        new.parent = self.parent
         return new
 
     def pin_memory(self, device=None):
@@ -50,17 +52,38 @@ class ParameterGGUF(torch.nn.Parameter):
         new.gguf_type = self.gguf_type
         new.gguf_real_shape = self.gguf_real_shape
         new.gguf_cls = self.gguf_cls
+        new.parent = self.parent
         return new
 
     @classmethod
-    def make(cls, data, gguf_type, gguf_cls, gguf_real_shape):
+    def make(cls, data, gguf_type, gguf_cls, gguf_real_shape, parent):
         new = ParameterGGUF(data, no_init=True)
         new.gguf_type = gguf_type
         new.gguf_real_shape = gguf_real_shape
         new.gguf_cls = gguf_cls
+        new.parent = parent
         return new
 
 
+def bake_gguf_model(model):
+    computation_dtype = model.computation_dtype
+    backed_layer_counter = 0
+
+    for m in model.modules():
+        if hasattr(m, 'weight'):
+            weight = m.weight
+            if hasattr(weight, 'gguf_cls'):
+                gguf_cls = weight.gguf_cls
+                if gguf_cls is not None:
+                    backed_layer_counter += 1
+                    gguf_cls.bake_layer(m, weight, computation_dtype)
+
+    if backed_layer_counter > 0:
+        print(f'GGUF backed {backed_layer_counter} layers.')
+
+    return model
+
+
 def dequantize_tensor(tensor):
     if tensor is None:
         return None
@@ -68,7 +91,7 @@ def dequantize_tensor(tensor):
     if not hasattr(tensor, 'gguf_cls'):
         return tensor
 
-    data = torch.tensor(tensor.data)
+    data = tensor
 
     gguf_cls = tensor.gguf_cls
     gguf_real_shape = tensor.gguf_real_shape
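The parent back-reference added throughout ParameterGGUF (and propagated through the parameter's copy paths, including pin_memory and make) is what lets the per-step dequantize find the tensors that bake_gguf_model stored on the owning module. A hedged toy of that linkage, using a plain nn.Linear as the host and fake 4-bit-style math (only the parent/quant_state_0 wiring mirrors the diff):

import torch

layer = torch.nn.Linear(4, 4, bias=False)
layer.quant_state_0 = torch.nn.Parameter(torch.full((3, 1), 0.5), requires_grad=False)   # baked scale

weight = torch.nn.Parameter(torch.randint(0, 16, (3, 8), dtype=torch.uint8), requires_grad=False)
weight.parent = layer                      # set once at load time in the real code

def toy_dequantize(w):
    d = w.parent.quant_state_0             # reached through the back-reference, not recomputed
    return d * (w.to(torch.int8) - 8)      # made-up arithmetic, not a real GGUF codec

print(toy_dequantize(weight).shape)        # torch.Size([3, 8])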


@@ -425,7 +425,8 @@ class LoraLoader:
                     data=weight,
                     gguf_type=gguf_type,
                     gguf_cls=gguf_cls,
-                    gguf_real_shape=gguf_real_shape
+                    gguf_real_shape=gguf_real_shape,
+                    parent=parent_layer
                 ))
 
                 continue
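The LoRA path rebuilds the GGUF parameter after patching, and any Python attribute attached to the old Parameter object (like parent) does not survive that reconstruction, which is why the commit threads parent_layer through ParameterGGUF.make(). A tiny illustration of the underlying behaviour (generic PyTorch, nothing Forge-specific):

import torch

old = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
old.parent = 'some layer'                      # ad-hoc attribute, as on ParameterGGUF

new = torch.nn.Parameter(old.detach().clone(), requires_grad=False)
print(hasattr(new, 'parent'))                  # False: the attribute did not carry over
new.parent = old.parent                        # has to be reattached explicitly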


@@ -89,6 +89,8 @@ class __Quant(ABC):
     grid_map: tuple[int | float, ...] = ()
     grid_hex: bytes | None = None
 
+    computation_dtype: torch.dtype = torch.bfloat16
+
     def __init__(self):
         return TypeError("Quant conversion classes can't have instances")
@@ -141,18 +143,29 @@ class __Quant(ABC):
         return blocks.reshape(original_shape)
 
     @classmethod
-    def dequantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
-        # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
-        block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
+    def bake_layer(cls, layer, weight, computation_dtype):
+        data = weight.data
+        cls.computation_dtype = computation_dtype
+        cls.block_size, cls.type_size = GGML_QUANT_SIZES[cls.qtype]
         rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
-        n_blocks = rows.numel() // type_size
-        blocks = rows.reshape((n_blocks, type_size))
-        blocks = cls.dequantize_blocks_pytorch(blocks, block_size, type_size)
+        n_blocks = rows.numel() // cls.type_size
+        blocks = rows.reshape((n_blocks, cls.type_size))
+        weight.data = blocks
+        cls.bake_layer_weight(layer, weight)
+        return
+
+    @classmethod
+    def bake_layer_weight(cls, layer, weight):
+        pass
+
+    @classmethod
+    def dequantize_pytorch(cls, x, original_shape) -> torch.Tensor:
+        blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x.parent)
         return blocks.reshape(original_shape)
 
     @classmethod
     @abstractmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         raise NotImplementedError
 
     @classmethod
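The base-class change above splits the old per-step dequantize_pytorch into two phases: bake_layer() runs once per layer at load time (reshaping the raw bytes into (n_blocks, type_size) rows and delegating to the per-type bake_layer_weight()), while dequantize_pytorch() still runs every step but now reaches any precomputed pieces through the weight's parent module. A hedged miniature of that contract with a fabricated one-byte-scale format (DemoQuant is not a real GGUF type, and the real quick_split helper is not used here):

import torch

class DemoQuant:
    block_size, type_size = 4, 5               # 4 quantized values + 1 scale byte per block
    computation_dtype = torch.bfloat16

    @classmethod
    def bake_layer_weight(cls, layer, weight):
        blocks = weight.data                   # (n_blocks, type_size) uint8, as bake_layer leaves it
        scale = blocks[:, :1].to(cls.computation_dtype)                  # header converted once
        weight.data = blocks[:, 1:]                                      # payload stays quantized
        layer.quant_state_0 = torch.nn.Parameter(scale, requires_grad=False)

    @classmethod
    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent):
        scale = parent.quant_state_0           # read back every step, never recomputed
        if scale.device != blocks.device:
            scale = scale.to(blocks.device)
        return scale * blocks.to(cls.computation_dtype)

layer = torch.nn.Linear(1, 1)                  # any module can host the baked state
w = torch.nn.Parameter(torch.randint(0, 256, (3, 5), dtype=torch.uint8), requires_grad=False)
DemoQuant.bake_layer_weight(layer, w)
print(DemoQuant.dequantize_blocks_pytorch(w.data, 4, 5, layer).shape)   # torch.Size([3, 4])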
@@ -289,15 +302,26 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         return (d * qs.astype(np.float32))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def bake_layer_weight(cls, layer, weight):
+        blocks = weight.data
+        d, x = quick_split(blocks, [2])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        weight.data = x
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        return
+
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         n_blocks = blocks.shape[0]
-        d = blocks[:, :2].view(torch.float16)
-        qs = blocks[:, 2:]
+        d, qs = parent.quant_state_0, blocks
+
+        if d.device != qs.device:
+            d = d.to(device=qs.device)
 
         qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
         qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
 
-        return d * qs
+        return (d * qs)
 
     @classmethod
     def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
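For Q4_0 concretely: every 18-byte block is a 2-byte fp16 scale d followed by 16 bytes packing 32 4-bit quants (low nibbles are elements 0-15, high nibbles elements 16-31), and a weight decodes as d * (q - 8). After bake_layer_weight() only the 16 packed bytes remain in weight.data and d lives on the layer, so the per-step cost is the nibble unpack plus one multiply. A hand-rolled check of that arithmetic on a single made-up block (illustrative, independent of the classes above):

import torch

d = torch.tensor([[0.5]], dtype=torch.float16)      # the baked per-block scale (1 block)
q = torch.arange(32, dtype=torch.uint8) % 16        # toy 4-bit quants: 0..15, 0..15
packed = (q[:16] | (q[16:] << 4)).reshape(1, 16)    # 16 bytes in ggml nibble order

lo = (packed & 0x0F).to(torch.int8) - 8             # elements 0..15
hi = (packed >> 4).to(torch.int8) - 8               # elements 16..31
values = d.to(torch.bfloat16) * torch.cat([lo, hi], dim=-1).to(torch.bfloat16)
print(values[0, :4])                                # -4.0, -3.5, -3.0, -2.5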
@@ -358,12 +382,29 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
         return (d * qs) + m
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def bake_layer_weight(cls, layer, weight):
+        blocks = weight.data
+        d, m, qs = quick_split(blocks, [2, 2])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        m = m.view(torch.float16).to(cls.computation_dtype)
+        weight.data = qs
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False)
+        return
+
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         n_blocks = blocks.shape[0]
 
-        d = blocks[:, :2].view(torch.float16)
-        m = blocks[:, 2:4].view(torch.float16)
-        qs = blocks[:, 4:]
+        d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks
+
+        if d.device != qs.device:
+            d = d.to(device=qs.device)
+
+        if m.device != qs.device:
+            m = m.to(device=qs.device)
 
         qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
         qs = (qs & 0x0F).reshape(n_blocks, -1)
@@ -414,7 +455,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         return (d * qs.astype(np.float32))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         def to_uint32(x):
             # pytorch uint32 by City96 - Apache-2.0
             x = x.view(torch.uint8).to(torch.int32)
@@ -422,11 +463,8 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         n_blocks = blocks.shape[0]
 
-        d = blocks[:, :2]
-        qh = blocks[:, 2:6]
-        qs = blocks[:, 6:]
-
-        d = d.view(torch.float16).to(torch.float32)
+        d, qh, qs = quick_split(blocks, [2, 4])
+        d = d.view(torch.float16).to(cls.computation_dtype)
 
         qh = to_uint32(qh)
         qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -436,7 +474,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         ql = (ql & 0x0F).reshape(n_blocks, -1)
         qs = (ql | (qh << 4)).to(torch.int8) - 16
 
-        return d * qs
+        return (d * qs)
 
     @classmethod
     def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
@@ -520,7 +558,7 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
         return (d * qs) + m
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         def to_uint32(x):
             # pytorch uint32 by City96 - Apache-2.0
             x = x.view(torch.uint8).to(torch.int32)
@@ -528,11 +566,9 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
         n_blocks = blocks.shape[0]
 
-        d = blocks[:, :2].view(torch.float16)
-        m = blocks[:, 2:4].view(torch.float16)
-        qh = blocks[:, 4:8]
-        qs = blocks[:, 8:]
+        d, m, qh, qs = quick_split(blocks, [2, 2, 4])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        m = m.view(torch.float16).to(cls.computation_dtype)
 
         qh = to_uint32(qh)
         qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -570,9 +606,22 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         return (x * d)
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
-        d = blocks[:, :2].view(torch.float16)
-        x = blocks[:, 2:].view(torch.int8).to(torch.float16)
+    def bake_layer_weight(cls, layer, weight):
+        blocks = weight.data
+        d, x = quick_split(blocks, [2])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        weight.data = x
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        return
+
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+        x = blocks
+        d = parent.quant_state_0
+
+        if d.device != x.device:
+            d = d.to(device=x.device)
 
         return x * d
 
     @classmethod
@@ -613,12 +662,12 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
         return qs.reshape((n_blocks, -1))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         n_blocks = blocks.shape[0]
         scales, qs, d, dmin = quick_split(blocks, [QK_K // 16, QK_K // 4, 2])
-        d = d.view(torch.float16)
-        dmin = dmin.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
         # (n_blocks, 16, 1)
         dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
         ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
@@ -673,11 +722,11 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
         return (dl * q).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         n_blocks = blocks.shape[0]
         hmask, qs, scales, d = quick_split(blocks, [QK_K // 8, QK_K // 4, 12])
-        d = d.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
         lscales, hscales = scales[:, :8], scales[:, 8:]
         lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
         lscales = lscales.reshape((n_blocks, 16))
@@ -754,14 +803,14 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
         return (d * qs - dm).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         QK_K = 256
         K_SCALE_SIZE = 12
         n_blocks = blocks.shape[0]
         d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE])
-        d = d.view(torch.float16)
-        dmin = dmin.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
         sc, m = Q4_K.get_scale_min_pytorch(scales)
         d = (d * sc).reshape((n_blocks, -1, 1))
         dm = (dmin * m).reshape((n_blocks, -1, 1))
@@ -797,14 +846,14 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
         return (d * q - dm).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         QK_K = 256
         K_SCALE_SIZE = 12
         n_blocks = blocks.shape[0]
         d, dmin, scales, qh, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE, QK_K // 8])
-        d = d.view(torch.float16)
-        dmin = dmin.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
         sc, m = Q4_K.get_scale_min_pytorch(scales)
         d = (d * sc).reshape((n_blocks, -1, 1))
         dm = (dmin * m).reshape((n_blocks, -1, 1))
@@ -839,12 +888,12 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
         return (d * q).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # Written by ChatGPT
         n_blocks = blocks.shape[0]
         ql, qh, scales, d, = quick_split(blocks, [QK_K // 2, QK_K // 4, QK_K // 16])
-        scales = scales.view(torch.int8)
-        d = d.view(torch.float16)
+        scales = scales.view(torch.int8).to(cls.computation_dtype)
+        d = d.view(torch.float16).to(cls.computation_dtype)
         d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
         ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
         ql = (ql & 0x0F).reshape((n_blocks, -1, 32))