From 4c9380c46ab9e046404ed2d068c6132e90661fbe Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Fri, 30 Aug 2024 00:49:05 -0700 Subject: [PATCH] Speed up quant model loading and inference ... MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ... based on 3 evidences: 1. torch.Tensor.view on one big tensor is slightly faster than calling torch.Tensor.to on multiple small tensors. 2. but torch.Tensor.to with dtype change is significantly slower than torch.Tensor.view 3. “baking” model on GPU is significantly faster than computing on CPU when model load. mainly influence inference of Q8_0, Q4_0/1/K and loading of all quants --- backend/loader.py | 9 -- backend/memory_management.py | 19 +++- backend/operations.py | 10 +- backend/operations_gguf.py | 77 ++++----------- backend/patcher/lora.py | 23 ++--- backend/utils.py | 6 +- packages_3rdparty/gguf/quants.py | 163 ++++++++++++++----------------- 7 files changed, 126 insertions(+), 181 deletions(-) diff --git a/backend/loader.py b/backend/loader.py index 734f8855..24dc4b40 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -104,11 +104,6 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale']) - if storage_dtype in ['gguf']: - from backend.operations_gguf import bake_gguf_model - model.computation_dtype = torch.float16 - model = bake_gguf_model(model) - return model if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']: assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!' @@ -167,10 +162,6 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p model.initial_device = initial_device model.offload_device = offload_device - if storage_dtype in ['gguf']: - from backend.operations_gguf import bake_gguf_model - model = bake_gguf_model(model) - return model print(f'Skipped: {component_name} = {lib_name}.{cls_name}') diff --git a/backend/memory_management.py b/backend/memory_management.py index 308f3f3a..c3d80edb 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -310,7 +310,7 @@ def state_dict_parameters(sd): def state_dict_dtype(state_dict): for k, v in state_dict.items(): - if hasattr(v, 'is_gguf'): + if hasattr(v, 'gguf_cls'): return 'gguf' if 'bitsandbytes__nf4' in k: return 'nf4' @@ -337,6 +337,19 @@ def state_dict_dtype(state_dict): return major_dtype +def bake_gguf_model(model): + if getattr(model, 'gguf_baked', False): + return + + for p in model.parameters(): + gguf_cls = getattr(p, 'gguf_cls', None) + if gguf_cls is not None: + gguf_cls.bake(p) + + model.gguf_baked = True + return model + + def module_size(module, exclude_device=None, return_split=False): module_mem = 0 weight_mem = 0 @@ -493,6 +506,8 @@ class LoadedModel: global signal_empty_cache signal_empty_cache = True + bake_gguf_model(self.real_model) + self.model.lora_loader.refresh(offload_device=self.model.offload_device) if is_intel_xpu() and not args.disable_ipex_hijack: @@ -642,7 +657,7 @@ def load_models_gpu(models, memory_required=0): inference_memory = minimum_inference_memory() estimated_remaining_memory = current_free_mem - model_memory - inference_memory - print(f"[Memory Management] Target: {loaded_model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="") + print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="") if estimated_remaining_memory < 0: vram_set_state = VRAMState.LOW_VRAM diff --git a/backend/operations.py b/backend/operations.py index 9666d40d..d843b678 100644 --- a/backend/operations.py +++ b/backend/operations.py @@ -395,20 +395,22 @@ class ForgeOperationsGGUF(ForgeOperations): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): if hasattr(self, 'dummy'): + computation_dtype = self.dummy.dtype + if computation_dtype not in [torch.float16, torch.bfloat16]: + # GGUF cast only supports 16bits otherwise super slow + computation_dtype = torch.float16 if prefix + 'weight' in state_dict: self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device) + self.weight.computation_dtype = computation_dtype if prefix + 'bias' in state_dict: self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device) + self.bias.computation_dtype = computation_dtype del self.dummy else: if prefix + 'weight' in state_dict: self.weight = state_dict[prefix + 'weight'] if prefix + 'bias' in state_dict: self.bias = state_dict[prefix + 'bias'] - if self.weight is not None and hasattr(self.weight, 'parent'): - self.weight.parent = self - if self.bias is not None and hasattr(self.bias, 'parent'): - self.bias.parent = self return def _apply(self, fn, recurse=True): diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py index 3d9cec66..f30ef7dd 100644 --- a/backend/operations_gguf.py +++ b/backend/operations_gguf.py @@ -19,77 +19,40 @@ quants_mapping = { class ParameterGGUF(torch.nn.Parameter): def __init__(self, tensor=None, requires_grad=False, no_init=False): super().__init__() - self.is_gguf = True - if no_init: return - self.gguf_type = tensor.tensor_type - self.gguf_real_shape = torch.Size(reversed(list(tensor.shape))) - self.gguf_cls = quants_mapping.get(self.gguf_type, None) - self.parent = None + self.gguf_cls = quants_mapping.get(tensor.tensor_type, None) + self.real_shape = torch.Size(reversed(list(tensor.shape))) + self.computation_dtype = torch.float16 + self.baked = False + return @property def shape(self): - return self.gguf_real_shape + return self.real_shape def __new__(cls, tensor=None, requires_grad=False, no_init=False): return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad) def dequantize_as_pytorch_parameter(self): - if self.parent is None: - self.parent = torch.nn.Module() - self.gguf_cls.bake_layer(self.parent, self, computation_dtype=torch.float16) + if self.gguf_cls is not None: + self.gguf_cls.bake(self) return torch.nn.Parameter(dequantize_tensor(self), requires_grad=False) - def to(self, *args, **kwargs): - new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True) - new.gguf_type = self.gguf_type - new.gguf_real_shape = self.gguf_real_shape + def copy_with_data(self, data): + new = ParameterGGUF(data, no_init=True) new.gguf_cls = self.gguf_cls - new.parent = self.parent + new.real_shape = self.real_shape + new.computation_dtype = self.computation_dtype + new.baked = self.baked return new + def to(self, *args, **kwargs): + return self.copy_with_data(self.data.to(*args, **kwargs)) + def pin_memory(self, device=None): - new = ParameterGGUF(torch.Tensor.pin_memory(self, device=device), no_init=True) - new.gguf_type = self.gguf_type - new.gguf_real_shape = self.gguf_real_shape - new.gguf_cls = self.gguf_cls - new.parent = self.parent - return new - - @classmethod - def make(cls, data, gguf_type, gguf_cls, gguf_real_shape, parent): - new = ParameterGGUF(data, no_init=True) - new.gguf_type = gguf_type - new.gguf_real_shape = gguf_real_shape - new.gguf_cls = gguf_cls - new.parent = parent - return new - - -def bake_gguf_model(model): - computation_dtype = model.computation_dtype - - if computation_dtype not in [torch.float16, torch.bfloat16]: - # Baking only supports 16bits otherwise super slow - computation_dtype = torch.float16 - - backed_layer_counter = 0 - - for m in model.modules(): - if hasattr(m, 'weight'): - weight = m.weight - if hasattr(weight, 'gguf_cls'): - gguf_cls = weight.gguf_cls - if gguf_cls is not None: - backed_layer_counter += 1 - gguf_cls.bake_layer(m, weight, computation_dtype) - - if backed_layer_counter > 0: - print(f'GGUF backed {backed_layer_counter} layers.') - - return model + return self.copy_with_data(torch.Tensor.pin_memory(self, device=device)) def dequantize_tensor(tensor): @@ -99,11 +62,9 @@ def dequantize_tensor(tensor): if not hasattr(tensor, 'gguf_cls'): return tensor - data = tensor gguf_cls = tensor.gguf_cls - gguf_real_shape = tensor.gguf_real_shape if gguf_cls is None: - return data + return tensor - return gguf_cls.dequantize_pytorch(data, gguf_real_shape) + return gguf_cls.dequantize_pytorch(tensor) diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py index 6e0e05d7..a9c7e84d 100644 --- a/backend/patcher/lora.py +++ b/backend/patcher/lora.py @@ -387,13 +387,12 @@ class LoraLoader: from backend.operations_bnb import functional_dequantize_4bit weight = functional_dequantize_4bit(weight) - gguf_cls, gguf_type, gguf_real_shape = None, None, None + gguf_cls = getattr(weight, 'gguf_cls', None) + gguf_parameter = None - if hasattr(weight, 'is_gguf'): + if gguf_cls is not None: + gguf_parameter = weight from backend.operations_gguf import dequantize_tensor - gguf_cls = weight.gguf_cls - gguf_type = weight.gguf_type - gguf_real_shape = weight.gguf_real_shape weight = dequantize_tensor(weight) try: @@ -409,17 +408,9 @@ class LoraLoader: continue if gguf_cls is not None: - from backend.operations_gguf import ParameterGGUF - weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape) - weight = ParameterGGUF.make( - data=weight, - gguf_type=gguf_type, - gguf_cls=gguf_cls, - gguf_real_shape=gguf_real_shape, - parent=parent_layer - ) - gguf_cls.bake_layer(parent_layer, weight, gguf_cls.computation_dtype) - utils.set_attr_raw(self.model, key, weight) + gguf_parameter.data = gguf_cls.quantize_pytorch(weight, gguf_parameter.shape) + gguf_parameter.baked = False + gguf_cls.bake(gguf_parameter) continue utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False)) diff --git a/backend/utils.py b/backend/utils.py index d023f577..7c7741db 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -157,9 +157,9 @@ def beautiful_print_gguf_state_dict_statics(state_dict): from gguf.constants import GGMLQuantizationType type_counts = {} for k, v in state_dict.items(): - gguf_type = getattr(v, 'gguf_type', None) - if gguf_type is not None: - type_name = GGMLQuantizationType(gguf_type).name + gguf_cls = getattr(v, 'gguf_cls', None) + if gguf_cls is not None: + type_name = gguf_cls.__name__ if type_name in type_counts: type_counts[type_name] += 1 else: diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py index c0d144d5..1fac91ff 100644 --- a/packages_3rdparty/gguf/quants.py +++ b/packages_3rdparty/gguf/quants.py @@ -13,7 +13,7 @@ from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpac import numpy as np -quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=1) +quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=-1) def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]: @@ -90,8 +90,6 @@ class __Quant(ABC): grid_map: tuple[int | float, ...] = () grid_hex: bytes | None = None - computation_dtype: torch.dtype = torch.bfloat16 - def __init__(self): return TypeError("Quant conversion classes can't have instances") @@ -144,29 +142,35 @@ class __Quant(ABC): return blocks.reshape(original_shape) @classmethod - def bake_layer(cls, layer, weight, computation_dtype): - data = weight.data - cls.computation_dtype = computation_dtype + def bake(cls, parameter): + if parameter.baked: + return + + data = parameter.data cls.block_size, cls.type_size = GGML_QUANT_SIZES[cls.qtype] rows = data.reshape((-1, data.shape[-1])).view(torch.uint8) n_blocks = rows.numel() // cls.type_size blocks = rows.reshape((n_blocks, cls.type_size)) - weight.data = blocks - cls.bake_layer_weight(layer, weight) + parameter.data = blocks.contiguous() + cls.bake_inner(parameter) + parameter.baked = True return @classmethod - def bake_layer_weight(cls, layer, weight): + def bake_inner(cls, parameter): pass @classmethod - def dequantize_pytorch(cls, x, original_shape) -> torch.Tensor: - blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x.parent) - return blocks.reshape(original_shape) + def dequantize_pytorch(cls, x): + if not x.baked: + raise ValueError('GGUF Tensor is not baked!') + + blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x) + return blocks.view(x.shape) @classmethod @abstractmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: raise NotImplementedError @classmethod @@ -303,22 +307,18 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0): return (d * qs.astype(np.float32)) @classmethod - def bake_layer_weight(cls, layer, weight): - blocks = weight.data + def bake_inner(cls, parameter): + blocks = parameter.data d, x = quick_split(blocks, [2]) - d = d.view(torch.float16).to(cls.computation_dtype) - x = change_4bits_order(x) - weight.data = x - layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) + d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8) + x = change_4bits_order(x).view(torch.uint8) + parameter.data = torch.cat([d, x], dim=-1).contiguous() return @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - d, qs = parent.quant_state_0, blocks - - if d.device != qs.device: - d = d.to(device=qs.device) - + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: + d, qs = quick_split(blocks, [2]) + d = d.view(parameter.computation_dtype) qs = quick_unpack_4bits(qs) return d * qs @@ -381,30 +381,23 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): return (d * qs) + m @classmethod - def bake_layer_weight(cls, layer, weight): - blocks = weight.data + def bake_inner(cls, parameter): + blocks = parameter.data d, m, qs = quick_split(blocks, [2, 2]) - d = d.view(torch.float16).to(cls.computation_dtype) - m = m.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8) + m = m.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8) + qs = change_4bits_order(qs).view(torch.uint8) - qs = change_4bits_order(qs) + parameter.data = torch.cat([d, m, qs], dim=-1).contiguous() - weight.data = qs - layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) - layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False) return @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks - - if d.device != qs.device: - d = d.to(device=qs.device) - - if m.device != qs.device: - m = m.to(device=qs.device) - + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: + d, m, qs = quick_split(blocks, [2, 2]) + d = d.view(parameter.computation_dtype) + m = m.view(parameter.computation_dtype) qs = quick_unpack_4bits_u(qs) return (d * qs) + m @@ -452,7 +445,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0): return (d * qs.astype(np.float32)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: def to_uint32(x): # pytorch uint32 by City96 - Apache-2.0 x = x.view(torch.uint8).to(torch.int32) @@ -461,7 +454,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0): n_blocks = blocks.shape[0] d, qh, qs = quick_split(blocks, [2, 4]) - d = d.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) qh = to_uint32(qh) qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) @@ -555,7 +548,7 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1): return (d * qs) + m @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: def to_uint32(x): # pytorch uint32 by City96 - Apache-2.0 x = x.view(torch.uint8).to(torch.int32) @@ -564,8 +557,8 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1): n_blocks = blocks.shape[0] d, m, qh, qs = quick_split(blocks, [2, 2, 4]) - d = d.view(torch.float16).to(cls.computation_dtype) - m = m.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) + m = m.view(torch.float16).to(parameter.computation_dtype) qh = to_uint32(qh) qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) @@ -603,23 +596,18 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0): return (x * d) @classmethod - def bake_layer_weight(cls, layer, weight): - blocks = weight.data + def bake_inner(cls, parameter): + blocks = parameter.data d, x = quick_split(blocks, [2]) x = x.view(torch.int8) - d = d.view(torch.float16).to(cls.computation_dtype) - weight.data = x - layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) + d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.int8) + parameter.data = torch.cat([d, x], dim=-1).contiguous() return @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - x = blocks - d = parent.quant_state_0 - - if d.device != x.device: - d = d.to(device=x.device) - + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: + d, x = quick_split(blocks, [2]) + d = d.view(parameter.computation_dtype) return x * d @classmethod @@ -660,12 +648,12 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K): return qs.reshape((n_blocks, -1)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) n_blocks = blocks.shape[0] scales, qs, d, dmin = quick_split(blocks, [QK_K // 16, QK_K // 4, 2]) - d = d.view(torch.float16).to(cls.computation_dtype) - dmin = dmin.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) + dmin = dmin.view(torch.float16).to(parameter.computation_dtype) # (n_blocks, 16, 1) dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1)) ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1)) @@ -720,11 +708,11 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K): return (dl * q).reshape((n_blocks, QK_K)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) n_blocks = blocks.shape[0] hmask, qs, scales, d = quick_split(blocks, [QK_K // 8, QK_K // 4, 12]) - d = d.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) lscales, hscales = scales[:, :8], scales[:, 8:] lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1)) lscales = lscales.reshape((n_blocks, 16)) @@ -801,42 +789,39 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): return (d * qs - dm).reshape((n_blocks, QK_K)) @classmethod - def bake_layer_weight(cls, layer, weight): # Only compute one time when model load + def bake_inner(cls, parameter): # Only compute one time when model load # Copyright Forge 2024 - blocks = weight.data - K_SCALE_SIZE = 12 + blocks = parameter.data n_blocks = blocks.shape[0] - d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE]) - d = d.view(torch.float16).to(cls.computation_dtype) - dmin = dmin.view(torch.float16).to(cls.computation_dtype) + d, dmin, scales, qs = quick_split(blocks, [2, 2, cls.K_SCALE_SIZE]) + d = d.view(torch.float16).to(parameter.computation_dtype) + dmin = dmin.view(torch.float16).to(parameter.computation_dtype) sc, m = Q4_K.get_scale_min_pytorch(scales) d = (d * sc).reshape((n_blocks, -1, 1)) - dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype) + dm = (dmin * m).reshape((n_blocks, -1, 1)).to(parameter.computation_dtype) qs = qs.reshape((n_blocks, -1, 1, 32)) qs = change_4bits_order(qs) - weight.data = qs - layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) - layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False) + d = d.view(torch.uint8).reshape((n_blocks, -1)) + dm = dm.view(torch.uint8).reshape((n_blocks, -1)) + qs = qs.view(torch.uint8) + + parameter.data = torch.cat([d, dm, qs], dim=-1).contiguous() return @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # Compute in each diffusion iteration n_blocks = blocks.shape[0] - d, dm, qs = parent.quant_state_0, parent.quant_state_1, blocks + d, dm, qs = quick_split(blocks, [16, 16]) + d = d.view(parameter.computation_dtype).view((n_blocks, -1, 1)) + dm = dm.view(parameter.computation_dtype).view((n_blocks, -1, 1)) + qs = quick_unpack_4bits_u(qs).view((n_blocks, -1, 32)) - if d.device != qs.device: - d = d.to(device=qs.device) - - if dm.device != qs.device: - dm = dm.to(device=qs.device) - - qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32)) return (d * qs - dm).reshape((n_blocks, QK_K)) @@ -867,14 +852,14 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K): return (d * q - dm).reshape((n_blocks, QK_K)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) QK_K = 256 K_SCALE_SIZE = 12 n_blocks = blocks.shape[0] d, dmin, scales, qh, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE, QK_K // 8]) - d = d.view(torch.float16).to(cls.computation_dtype) - dmin = dmin.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) + dmin = dmin.view(torch.float16).to(parameter.computation_dtype) sc, m = Q4_K.get_scale_min_pytorch(scales) d = (d * sc).reshape((n_blocks, -1, 1)) dm = (dmin * m).reshape((n_blocks, -1, 1)) @@ -909,12 +894,12 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K): return (d * q).reshape((n_blocks, QK_K)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # Written by ChatGPT n_blocks = blocks.shape[0] ql, qh, scales, d, = quick_split(blocks, [QK_K // 2, QK_K // 4, QK_K // 16]) - scales = scales.view(torch.int8).to(cls.computation_dtype) - d = d.view(torch.float16).to(cls.computation_dtype) + scales = scales.view(torch.int8).to(parameter.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) ql = (ql & 0x0F).reshape((n_blocks, -1, 32))