diff --git a/backend/loader.py b/backend/loader.py
index 24dc4b40..9154c92e 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -162,6 +162,10 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         model.initial_device = initial_device
         model.offload_device = offload_device
 
+        if storage_dtype in ['gguf']:
+            from backend.operations_gguf import bake_gguf_model
+            model = bake_gguf_model(model)
+
         return model
 
     print(f'Skipped: {component_name} = {lib_name}.{cls_name}')
diff --git a/backend/operations.py b/backend/operations.py
index 72cbfc0d..b3d34b82 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -405,12 +405,24 @@ class ForgeOperationsGGUF(ForgeOperations):
                 self.weight = state_dict[prefix + 'weight']
             if prefix + 'bias' in state_dict:
                 self.bias = state_dict[prefix + 'bias']
+            if self.weight is not None and hasattr(self.weight, 'parent'):
+                self.weight.parent = self
+            if self.bias is not None and hasattr(self.bias, 'parent'):
+                self.bias.parent = self
+
             return
 
         def _apply(self, fn, recurse=True):
             if self.weight is not None:
                 self.weight = utils.tensor2parameter(fn(self.weight))
             if self.bias is not None:
                 self.bias = utils.tensor2parameter(fn(self.bias))
+            for i in range(5):
+                quant_state_name = f'quant_state_{i}'
+                quant_state = getattr(self, quant_state_name, None)
+                if quant_state is not None:
+                    quant_state = fn(quant_state)
+                    quant_state = utils.tensor2parameter(quant_state)
+                    setattr(self, quant_state_name, quant_state)
             return self
 
         def forward(self, x):
diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py
index 72da4604..5e190b40 100644
--- a/backend/operations_gguf.py
+++ b/backend/operations_gguf.py
@@ -27,6 +27,7 @@ class ParameterGGUF(torch.nn.Parameter):
         self.gguf_type = tensor.tensor_type
         self.gguf_real_shape = torch.Size(reversed(list(tensor.shape)))
         self.gguf_cls = quants_mapping.get(self.gguf_type, None)
+        self.parent = None
 
     @property
     def shape(self):
@@ -43,6 +44,7 @@
         new.gguf_type = self.gguf_type
         new.gguf_real_shape = self.gguf_real_shape
         new.gguf_cls = self.gguf_cls
+        new.parent = self.parent
         return new
 
     def pin_memory(self, device=None):
@@ -50,17 +52,38 @@
         new.gguf_type = self.gguf_type
         new.gguf_real_shape = self.gguf_real_shape
         new.gguf_cls = self.gguf_cls
+        new.parent = self.parent
         return new
 
     @classmethod
-    def make(cls, data, gguf_type, gguf_cls, gguf_real_shape):
+    def make(cls, data, gguf_type, gguf_cls, gguf_real_shape, parent):
         new = ParameterGGUF(data, no_init=True)
         new.gguf_type = gguf_type
         new.gguf_real_shape = gguf_real_shape
         new.gguf_cls = gguf_cls
+        new.parent = parent
         return new
 
 
+def bake_gguf_model(model):
+    computation_dtype = model.computation_dtype
+    backed_layer_counter = 0
+
+    for m in model.modules():
+        if hasattr(m, 'weight'):
+            weight = m.weight
+            if hasattr(weight, 'gguf_cls'):
+                gguf_cls = weight.gguf_cls
+                if gguf_cls is not None:
+                    backed_layer_counter += 1
+                    gguf_cls.bake_layer(m, weight, computation_dtype)
+
+    if backed_layer_counter > 0:
+        print(f'GGUF backed {backed_layer_counter} layers.')
+
+    return model
+
+
 def dequantize_tensor(tensor):
     if tensor is None:
         return None
@@ -68,7 +91,7 @@
     if not hasattr(tensor, 'gguf_cls'):
         return tensor
 
-    data = torch.tensor(tensor.data)
+    data = tensor
     gguf_cls = tensor.gguf_cls
     gguf_real_shape = tensor.gguf_real_shape
diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index 83da8fe5..fdd6f67d 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -425,7 +425,8 @@ class LoraLoader:
                     data=weight,
                     gguf_type=gguf_type,
                     gguf_cls=gguf_cls,
-                    gguf_real_shape=gguf_real_shape
+                    gguf_real_shape=gguf_real_shape,
+                    parent=parent_layer
                 ))
                 continue
diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py
index c6c80c91..84083da9 100644
--- a/packages_3rdparty/gguf/quants.py
+++ b/packages_3rdparty/gguf/quants.py
@@ -89,6 +89,8 @@ class __Quant(ABC):
     grid_map: tuple[int | float, ...] = ()
     grid_hex: bytes | None = None
 
+    computation_dtype: torch.dtype = torch.bfloat16
+
     def __init__(self):
         return TypeError("Quant conversion classes can't have instances")
@@ -141,18 +143,29 @@
         return blocks.reshape(original_shape)
 
     @classmethod
-    def dequantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
-        # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
-        block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
+    def bake_layer(cls, layer, weight, computation_dtype):
+        data = weight.data
+        cls.computation_dtype = computation_dtype
+        cls.block_size, cls.type_size = GGML_QUANT_SIZES[cls.qtype]
         rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
-        n_blocks = rows.numel() // type_size
-        blocks = rows.reshape((n_blocks, type_size))
-        blocks = cls.dequantize_blocks_pytorch(blocks, block_size, type_size)
+        n_blocks = rows.numel() // cls.type_size
+        blocks = rows.reshape((n_blocks, cls.type_size))
+        weight.data = blocks
+        cls.bake_layer_weight(layer, weight)
+        return
+
+    @classmethod
+    def bake_layer_weight(cls, layer, weight):
+        pass
+
+    @classmethod
+    def dequantize_pytorch(cls, x, original_shape) -> torch.Tensor:
+        blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x.parent)
         return blocks.reshape(original_shape)
 
     @classmethod
     @abstractmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         raise NotImplementedError
 
     @classmethod
@@ -289,15 +302,26 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         return (d * qs.astype(np.float32))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def bake_layer_weight(cls, layer, weight):
+        blocks = weight.data
+        d, x = quick_split(blocks, [2])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        weight.data = x
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        return
+
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         n_blocks = blocks.shape[0]
-        d = blocks[:, :2].view(torch.float16)
-        qs = blocks[:, 2:]
+        d, qs = parent.quant_state_0, blocks
+
+        if d.device != qs.device:
+            d = d.to(device=qs.device)
 
         qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
         qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
 
-        return d * qs
+        return (d * qs)
 
     @classmethod
     def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
@@ -358,12 +382,29 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
         return (d * qs) + m
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def bake_layer_weight(cls, layer, weight):
+        blocks = weight.data
+
+        d, m, qs = quick_split(blocks, [2, 2])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        m = m.view(torch.float16).to(cls.computation_dtype)
+
+        weight.data = qs
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False)
+        return
+
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         n_blocks = blocks.shape[0]
-        d = blocks[:, :2].view(torch.float16)
-        m = blocks[:, 2:4].view(torch.float16)
-        qs = blocks[:, 4:]
+        d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks
+
+        if d.device != qs.device:
+            d = d.to(device=qs.device)
+
+        if m.device != qs.device:
+            m = m.to(device=qs.device)
 
         qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
         qs = (qs & 0x0F).reshape(n_blocks, -1)
@@ -414,7 +455,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         return (d * qs.astype(np.float32))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         def to_uint32(x):
             # pytorch uint32 by City96 - Apache-2.0
             x = x.view(torch.uint8).to(torch.int32)
@@ -422,11 +463,8 @@
 
         n_blocks = blocks.shape[0]
 
-        d = blocks[:, :2]
-        qh = blocks[:, 2:6]
-        qs = blocks[:, 6:]
-
-        d = d.view(torch.float16).to(torch.float32)
+        d, qh, qs = quick_split(blocks, [2, 4])
+        d = d.view(torch.float16).to(cls.computation_dtype)
         qh = to_uint32(qh)
 
         qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -436,7 +474,7 @@
         ql = (ql & 0x0F).reshape(n_blocks, -1)
         qs = (ql | (qh << 4)).to(torch.int8) - 16
 
-        return d * qs
+        return (d * qs)
 
     @classmethod
     def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
@@ -520,7 +558,7 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
         return (d * qs) + m
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         def to_uint32(x):
             # pytorch uint32 by City96 - Apache-2.0
             x = x.view(torch.uint8).to(torch.int32)
@@ -528,11 +566,9 @@
 
         n_blocks = blocks.shape[0]
 
-        d = blocks[:, :2].view(torch.float16)
-        m = blocks[:, 2:4].view(torch.float16)
-        qh = blocks[:, 4:8]
-        qs = blocks[:, 8:]
-
+        d, m, qh, qs = quick_split(blocks, [2, 2, 4])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        m = m.view(torch.float16).to(cls.computation_dtype)
         qh = to_uint32(qh)
 
         qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
@@ -570,9 +606,22 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         return (x * d)
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
-        d = blocks[:, :2].view(torch.float16)
-        x = blocks[:, 2:].view(torch.int8).to(torch.float16)
+    def bake_layer_weight(cls, layer, weight):
+        blocks = weight.data
+        d, x = quick_split(blocks, [2])
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        weight.data = x
+        layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
+        return
+
+    @classmethod
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
+        x = blocks
+        d = parent.quant_state_0
+
+        if d.device != x.device:
+            d = d.to(device=x.device)
+
         return x * d
 
     @classmethod
@@ -613,12 +662,12 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
         return qs.reshape((n_blocks, -1))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         n_blocks = blocks.shape[0]
         scales, qs, d, dmin = quick_split(blocks, [QK_K // 16, QK_K // 4, 2])
-        d = d.view(torch.float16)
-        dmin = dmin.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
         # (n_blocks, 16, 1)
         dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
         ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
@@ -673,11 +722,11 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
         return (dl * q).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         n_blocks = blocks.shape[0]
         hmask, qs, scales, d = quick_split(blocks, [QK_K // 8, QK_K // 4, 12])
-        d = d.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
         lscales, hscales = scales[:, :8], scales[:, 8:]
         lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
         lscales = lscales.reshape((n_blocks, 16))
@@ -754,14 +803,14 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
         return (d * qs - dm).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         QK_K = 256
         K_SCALE_SIZE = 12
         n_blocks = blocks.shape[0]
         d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE])
-        d = d.view(torch.float16)
-        dmin = dmin.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
         sc, m = Q4_K.get_scale_min_pytorch(scales)
         d = (d * sc).reshape((n_blocks, -1, 1))
         dm = (dmin * m).reshape((n_blocks, -1, 1))
@@ -797,14 +846,14 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
         return (d * q - dm).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
         QK_K = 256
         K_SCALE_SIZE = 12
         n_blocks = blocks.shape[0]
         d, dmin, scales, qh, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE, QK_K // 8])
-        d = d.view(torch.float16)
-        dmin = dmin.view(torch.float16)
+        d = d.view(torch.float16).to(cls.computation_dtype)
+        dmin = dmin.view(torch.float16).to(cls.computation_dtype)
         sc, m = Q4_K.get_scale_min_pytorch(scales)
         d = (d * sc).reshape((n_blocks, -1, 1))
         dm = (dmin * m).reshape((n_blocks, -1, 1))
@@ -839,12 +888,12 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
         return (d * q).reshape((n_blocks, QK_K))
 
     @classmethod
-    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+    def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor:
         # Written by ChatGPT
         n_blocks = blocks.shape[0]
         ql, qh, scales, d, = quick_split(blocks, [QK_K // 2, QK_K // 4, QK_K // 16])
-        scales = scales.view(torch.int8)
-        d = d.view(torch.float16)
+        scales = scales.view(torch.int8).to(cls.computation_dtype)
+        d = d.view(torch.float16).to(cls.computation_dtype)
         d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
         ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
         ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
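Note for reviewers: taken together, these changes split each GGUF layer's float16 scale/offset bytes out of the packed weight once at load time (`bake_layer` / `bake_layer_weight`), store them on the owning module as `quant_state_*` parameters in the model's computation dtype, and give `ParameterGGUF` a `parent` back-pointer so `dequantize_blocks_pytorch` can reach those baked states at run time; the loop over `quant_state_{i}` in `_apply` keeps them registered through device moves. The snippet below is a minimal, self-contained sketch of that flow for the simplest case (Q8_0: 2 bytes of fp16 scale plus 32 int8 values per block). It does not reuse the repo's code; `FakeLinear`, `make_q8_0_blocks`, `bake_q8_0`, and `dequantize_q8_0` are hypothetical stand-ins for illustration only.

import torch

QK8_0 = 32      # values per Q8_0 block
TYPE_SIZE = 34  # 2 bytes of fp16 scale + 32 int8 quants per block


def make_q8_0_blocks(n_blocks):
    # Build raw uint8 rows shaped (n_blocks, TYPE_SIZE): [d: fp16 | qs: int8 * 32].
    d = (torch.rand(n_blocks, 1) * 0.1).to(torch.float16)
    qs = torch.randint(-128, 128, (n_blocks, QK8_0), dtype=torch.int8)
    return torch.cat([d.view(torch.uint8), qs.view(torch.uint8)], dim=-1), d, qs


class FakeLinear(torch.nn.Module):
    # Stand-in for the GGUF Linear layer that receives quant_state_* parameters.
    pass


def bake_q8_0(layer, blocks, computation_dtype=torch.bfloat16):
    # "Bake": split the scale bytes off once, cast them to the computation dtype,
    # and register them on the layer so later device moves carry them along.
    d = blocks[:, :2].view(torch.float16).to(computation_dtype)
    layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False)
    return blocks[:, 2:]  # the weight keeps only the int8 quants


def dequantize_q8_0(qs_bytes, parent):
    # Run-time path: fetch the pre-baked scale from the parent layer.
    d = parent.quant_state_0
    if d.device != qs_bytes.device:
        d = d.to(device=qs_bytes.device)
    return qs_bytes.view(torch.int8).to(d.dtype) * d


blocks, d_ref, qs_ref = make_q8_0_blocks(n_blocks=4)
layer = FakeLinear()
qs_only = bake_q8_0(layer, blocks)
restored = dequantize_q8_0(qs_only, parent=layer)
expected = qs_ref.to(torch.bfloat16) * d_ref.to(torch.bfloat16)
print(torch.allclose(restored, expected))  # True

The payoff, as I read the diff, is that the per-block scales are sliced out of the raw bytes and converted to the computation dtype exactly once instead of on every dequantization, and because they live in registered parameters they follow the module through `_apply` during offload.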