diff --git a/backend/loader.py b/backend/loader.py index 734f8855..24dc4b40 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -104,11 +104,6 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale']) - if storage_dtype in ['gguf']: - from backend.operations_gguf import bake_gguf_model - model.computation_dtype = torch.float16 - model = bake_gguf_model(model) - return model if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']: assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!' @@ -167,10 +162,6 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p model.initial_device = initial_device model.offload_device = offload_device - if storage_dtype in ['gguf']: - from backend.operations_gguf import bake_gguf_model - model = bake_gguf_model(model) - return model print(f'Skipped: {component_name} = {lib_name}.{cls_name}') diff --git a/backend/memory_management.py b/backend/memory_management.py index 308f3f3a..c3d80edb 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -310,7 +310,7 @@ def state_dict_parameters(sd): def state_dict_dtype(state_dict): for k, v in state_dict.items(): - if hasattr(v, 'is_gguf'): + if hasattr(v, 'gguf_cls'): return 'gguf' if 'bitsandbytes__nf4' in k: return 'nf4' @@ -337,6 +337,19 @@ def state_dict_dtype(state_dict): return major_dtype +def bake_gguf_model(model): + if getattr(model, 'gguf_baked', False): + return + + for p in model.parameters(): + gguf_cls = getattr(p, 'gguf_cls', None) + if gguf_cls is not None: + gguf_cls.bake(p) + + model.gguf_baked = True + return model + + def module_size(module, exclude_device=None, return_split=False): module_mem = 0 weight_mem = 0 @@ -493,6 +506,8 @@ class LoadedModel: global signal_empty_cache signal_empty_cache = True + bake_gguf_model(self.real_model) + self.model.lora_loader.refresh(offload_device=self.model.offload_device) if is_intel_xpu() and not args.disable_ipex_hijack: @@ -642,7 +657,7 @@ def load_models_gpu(models, memory_required=0): inference_memory = minimum_inference_memory() estimated_remaining_memory = current_free_mem - model_memory - inference_memory - print(f"[Memory Management] Target: {loaded_model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="") + print(f"[Memory Management] Target: {loaded_model.model.model.__class__.__name__}, Free GPU: {current_free_mem / (1024 * 1024):.2f} MB, Model Require: {model_memory / (1024 * 1024):.2f} MB, Inference Require: {inference_memory / (1024 * 1024):.2f} MB, Remaining: {estimated_remaining_memory / (1024 * 1024):.2f} MB, ", end="") if estimated_remaining_memory < 0: vram_set_state = VRAMState.LOW_VRAM diff --git a/backend/operations.py b/backend/operations.py index 9666d40d..d843b678 100644 --- a/backend/operations.py +++ b/backend/operations.py @@ -395,20 +395,22 @@ class ForgeOperationsGGUF(ForgeOperations): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): if hasattr(self, 'dummy'): + computation_dtype = self.dummy.dtype + if computation_dtype not in [torch.float16, torch.bfloat16]: + # GGUF cast only supports 16bits otherwise super slow + computation_dtype = torch.float16 if prefix + 'weight' in state_dict: self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device) + self.weight.computation_dtype = computation_dtype if prefix + 'bias' in state_dict: self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device) + self.bias.computation_dtype = computation_dtype del self.dummy else: if prefix + 'weight' in state_dict: self.weight = state_dict[prefix + 'weight'] if prefix + 'bias' in state_dict: self.bias = state_dict[prefix + 'bias'] - if self.weight is not None and hasattr(self.weight, 'parent'): - self.weight.parent = self - if self.bias is not None and hasattr(self.bias, 'parent'): - self.bias.parent = self return def _apply(self, fn, recurse=True): diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py index 3d9cec66..f30ef7dd 100644 --- a/backend/operations_gguf.py +++ b/backend/operations_gguf.py @@ -19,77 +19,40 @@ quants_mapping = { class ParameterGGUF(torch.nn.Parameter): def __init__(self, tensor=None, requires_grad=False, no_init=False): super().__init__() - self.is_gguf = True - if no_init: return - self.gguf_type = tensor.tensor_type - self.gguf_real_shape = torch.Size(reversed(list(tensor.shape))) - self.gguf_cls = quants_mapping.get(self.gguf_type, None) - self.parent = None + self.gguf_cls = quants_mapping.get(tensor.tensor_type, None) + self.real_shape = torch.Size(reversed(list(tensor.shape))) + self.computation_dtype = torch.float16 + self.baked = False + return @property def shape(self): - return self.gguf_real_shape + return self.real_shape def __new__(cls, tensor=None, requires_grad=False, no_init=False): return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad) def dequantize_as_pytorch_parameter(self): - if self.parent is None: - self.parent = torch.nn.Module() - self.gguf_cls.bake_layer(self.parent, self, computation_dtype=torch.float16) + if self.gguf_cls is not None: + self.gguf_cls.bake(self) return torch.nn.Parameter(dequantize_tensor(self), requires_grad=False) - def to(self, *args, **kwargs): - new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True) - new.gguf_type = self.gguf_type - new.gguf_real_shape = self.gguf_real_shape + def copy_with_data(self, data): + new = ParameterGGUF(data, no_init=True) new.gguf_cls = self.gguf_cls - new.parent = self.parent + new.real_shape = self.real_shape + new.computation_dtype = self.computation_dtype + new.baked = self.baked return new + def to(self, *args, **kwargs): + return self.copy_with_data(self.data.to(*args, **kwargs)) + def pin_memory(self, device=None): - new = ParameterGGUF(torch.Tensor.pin_memory(self, device=device), no_init=True) - new.gguf_type = self.gguf_type - new.gguf_real_shape = self.gguf_real_shape - new.gguf_cls = self.gguf_cls - new.parent = self.parent - return new - - @classmethod - def make(cls, data, gguf_type, gguf_cls, gguf_real_shape, parent): - new = ParameterGGUF(data, no_init=True) - new.gguf_type = gguf_type - new.gguf_real_shape = gguf_real_shape - new.gguf_cls = gguf_cls - new.parent = parent - return new - - -def bake_gguf_model(model): - computation_dtype = model.computation_dtype - - if computation_dtype not in [torch.float16, torch.bfloat16]: - # Baking only supports 16bits otherwise super slow - computation_dtype = torch.float16 - - backed_layer_counter = 0 - - for m in model.modules(): - if hasattr(m, 'weight'): - weight = m.weight - if hasattr(weight, 'gguf_cls'): - gguf_cls = weight.gguf_cls - if gguf_cls is not None: - backed_layer_counter += 1 - gguf_cls.bake_layer(m, weight, computation_dtype) - - if backed_layer_counter > 0: - print(f'GGUF backed {backed_layer_counter} layers.') - - return model + return self.copy_with_data(torch.Tensor.pin_memory(self, device=device)) def dequantize_tensor(tensor): @@ -99,11 +62,9 @@ def dequantize_tensor(tensor): if not hasattr(tensor, 'gguf_cls'): return tensor - data = tensor gguf_cls = tensor.gguf_cls - gguf_real_shape = tensor.gguf_real_shape if gguf_cls is None: - return data + return tensor - return gguf_cls.dequantize_pytorch(data, gguf_real_shape) + return gguf_cls.dequantize_pytorch(tensor) diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py index 6e0e05d7..a9c7e84d 100644 --- a/backend/patcher/lora.py +++ b/backend/patcher/lora.py @@ -387,13 +387,12 @@ class LoraLoader: from backend.operations_bnb import functional_dequantize_4bit weight = functional_dequantize_4bit(weight) - gguf_cls, gguf_type, gguf_real_shape = None, None, None + gguf_cls = getattr(weight, 'gguf_cls', None) + gguf_parameter = None - if hasattr(weight, 'is_gguf'): + if gguf_cls is not None: + gguf_parameter = weight from backend.operations_gguf import dequantize_tensor - gguf_cls = weight.gguf_cls - gguf_type = weight.gguf_type - gguf_real_shape = weight.gguf_real_shape weight = dequantize_tensor(weight) try: @@ -409,17 +408,9 @@ class LoraLoader: continue if gguf_cls is not None: - from backend.operations_gguf import ParameterGGUF - weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape) - weight = ParameterGGUF.make( - data=weight, - gguf_type=gguf_type, - gguf_cls=gguf_cls, - gguf_real_shape=gguf_real_shape, - parent=parent_layer - ) - gguf_cls.bake_layer(parent_layer, weight, gguf_cls.computation_dtype) - utils.set_attr_raw(self.model, key, weight) + gguf_parameter.data = gguf_cls.quantize_pytorch(weight, gguf_parameter.shape) + gguf_parameter.baked = False + gguf_cls.bake(gguf_parameter) continue utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False)) diff --git a/backend/utils.py b/backend/utils.py index d023f577..7c7741db 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -157,9 +157,9 @@ def beautiful_print_gguf_state_dict_statics(state_dict): from gguf.constants import GGMLQuantizationType type_counts = {} for k, v in state_dict.items(): - gguf_type = getattr(v, 'gguf_type', None) - if gguf_type is not None: - type_name = GGMLQuantizationType(gguf_type).name + gguf_cls = getattr(v, 'gguf_cls', None) + if gguf_cls is not None: + type_name = gguf_cls.__name__ if type_name in type_counts: type_counts[type_name] += 1 else: diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py index c0d144d5..1fac91ff 100644 --- a/packages_3rdparty/gguf/quants.py +++ b/packages_3rdparty/gguf/quants.py @@ -13,7 +13,7 @@ from .quick_4bits_ops import change_4bits_order, quick_unpack_4bits, quick_unpac import numpy as np -quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=1) +quick_split = lambda x, p: torch.split(x, p + [x.shape[1] - sum(p)], dim=-1) def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]: @@ -90,8 +90,6 @@ class __Quant(ABC): grid_map: tuple[int | float, ...] = () grid_hex: bytes | None = None - computation_dtype: torch.dtype = torch.bfloat16 - def __init__(self): return TypeError("Quant conversion classes can't have instances") @@ -144,29 +142,35 @@ class __Quant(ABC): return blocks.reshape(original_shape) @classmethod - def bake_layer(cls, layer, weight, computation_dtype): - data = weight.data - cls.computation_dtype = computation_dtype + def bake(cls, parameter): + if parameter.baked: + return + + data = parameter.data cls.block_size, cls.type_size = GGML_QUANT_SIZES[cls.qtype] rows = data.reshape((-1, data.shape[-1])).view(torch.uint8) n_blocks = rows.numel() // cls.type_size blocks = rows.reshape((n_blocks, cls.type_size)) - weight.data = blocks - cls.bake_layer_weight(layer, weight) + parameter.data = blocks.contiguous() + cls.bake_inner(parameter) + parameter.baked = True return @classmethod - def bake_layer_weight(cls, layer, weight): + def bake_inner(cls, parameter): pass @classmethod - def dequantize_pytorch(cls, x, original_shape) -> torch.Tensor: - blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x.parent) - return blocks.reshape(original_shape) + def dequantize_pytorch(cls, x): + if not x.baked: + raise ValueError('GGUF Tensor is not baked!') + + blocks = cls.dequantize_blocks_pytorch(x.data, cls.block_size, cls.type_size, x) + return blocks.view(x.shape) @classmethod @abstractmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: raise NotImplementedError @classmethod @@ -303,22 +307,18 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0): return (d * qs.astype(np.float32)) @classmethod - def bake_layer_weight(cls, layer, weight): - blocks = weight.data + def bake_inner(cls, parameter): + blocks = parameter.data d, x = quick_split(blocks, [2]) - d = d.view(torch.float16).to(cls.computation_dtype) - x = change_4bits_order(x) - weight.data = x - layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) + d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8) + x = change_4bits_order(x).view(torch.uint8) + parameter.data = torch.cat([d, x], dim=-1).contiguous() return @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - d, qs = parent.quant_state_0, blocks - - if d.device != qs.device: - d = d.to(device=qs.device) - + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: + d, qs = quick_split(blocks, [2]) + d = d.view(parameter.computation_dtype) qs = quick_unpack_4bits(qs) return d * qs @@ -381,30 +381,23 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): return (d * qs) + m @classmethod - def bake_layer_weight(cls, layer, weight): - blocks = weight.data + def bake_inner(cls, parameter): + blocks = parameter.data d, m, qs = quick_split(blocks, [2, 2]) - d = d.view(torch.float16).to(cls.computation_dtype) - m = m.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8) + m = m.view(torch.float16).to(parameter.computation_dtype).view(torch.uint8) + qs = change_4bits_order(qs).view(torch.uint8) - qs = change_4bits_order(qs) + parameter.data = torch.cat([d, m, qs], dim=-1).contiguous() - weight.data = qs - layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) - layer.quant_state_1 = torch.nn.Parameter(m, requires_grad=False) return @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - d, m, qs = parent.quant_state_0, parent.quant_state_1, blocks - - if d.device != qs.device: - d = d.to(device=qs.device) - - if m.device != qs.device: - m = m.to(device=qs.device) - + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: + d, m, qs = quick_split(blocks, [2, 2]) + d = d.view(parameter.computation_dtype) + m = m.view(parameter.computation_dtype) qs = quick_unpack_4bits_u(qs) return (d * qs) + m @@ -452,7 +445,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0): return (d * qs.astype(np.float32)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: def to_uint32(x): # pytorch uint32 by City96 - Apache-2.0 x = x.view(torch.uint8).to(torch.int32) @@ -461,7 +454,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0): n_blocks = blocks.shape[0] d, qh, qs = quick_split(blocks, [2, 4]) - d = d.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) qh = to_uint32(qh) qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) @@ -555,7 +548,7 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1): return (d * qs) + m @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: def to_uint32(x): # pytorch uint32 by City96 - Apache-2.0 x = x.view(torch.uint8).to(torch.int32) @@ -564,8 +557,8 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1): n_blocks = blocks.shape[0] d, m, qh, qs = quick_split(blocks, [2, 2, 4]) - d = d.view(torch.float16).to(cls.computation_dtype) - m = m.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) + m = m.view(torch.float16).to(parameter.computation_dtype) qh = to_uint32(qh) qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) @@ -603,23 +596,18 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0): return (x * d) @classmethod - def bake_layer_weight(cls, layer, weight): - blocks = weight.data + def bake_inner(cls, parameter): + blocks = parameter.data d, x = quick_split(blocks, [2]) x = x.view(torch.int8) - d = d.view(torch.float16).to(cls.computation_dtype) - weight.data = x - layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) + d = d.view(torch.float16).to(parameter.computation_dtype).view(torch.int8) + parameter.data = torch.cat([d, x], dim=-1).contiguous() return @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: - x = blocks - d = parent.quant_state_0 - - if d.device != x.device: - d = d.to(device=x.device) - + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: + d, x = quick_split(blocks, [2]) + d = d.view(parameter.computation_dtype) return x * d @classmethod @@ -660,12 +648,12 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K): return qs.reshape((n_blocks, -1)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) n_blocks = blocks.shape[0] scales, qs, d, dmin = quick_split(blocks, [QK_K // 16, QK_K // 4, 2]) - d = d.view(torch.float16).to(cls.computation_dtype) - dmin = dmin.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) + dmin = dmin.view(torch.float16).to(parameter.computation_dtype) # (n_blocks, 16, 1) dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1)) ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1)) @@ -720,11 +708,11 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K): return (dl * q).reshape((n_blocks, QK_K)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) n_blocks = blocks.shape[0] hmask, qs, scales, d = quick_split(blocks, [QK_K // 8, QK_K // 4, 12]) - d = d.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) lscales, hscales = scales[:, :8], scales[:, 8:] lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1)) lscales = lscales.reshape((n_blocks, 16)) @@ -801,42 +789,39 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): return (d * qs - dm).reshape((n_blocks, QK_K)) @classmethod - def bake_layer_weight(cls, layer, weight): # Only compute one time when model load + def bake_inner(cls, parameter): # Only compute one time when model load # Copyright Forge 2024 - blocks = weight.data - K_SCALE_SIZE = 12 + blocks = parameter.data n_blocks = blocks.shape[0] - d, dmin, scales, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE]) - d = d.view(torch.float16).to(cls.computation_dtype) - dmin = dmin.view(torch.float16).to(cls.computation_dtype) + d, dmin, scales, qs = quick_split(blocks, [2, 2, cls.K_SCALE_SIZE]) + d = d.view(torch.float16).to(parameter.computation_dtype) + dmin = dmin.view(torch.float16).to(parameter.computation_dtype) sc, m = Q4_K.get_scale_min_pytorch(scales) d = (d * sc).reshape((n_blocks, -1, 1)) - dm = (dmin * m).reshape((n_blocks, -1, 1)).to(cls.computation_dtype) + dm = (dmin * m).reshape((n_blocks, -1, 1)).to(parameter.computation_dtype) qs = qs.reshape((n_blocks, -1, 1, 32)) qs = change_4bits_order(qs) - weight.data = qs - layer.quant_state_0 = torch.nn.Parameter(d, requires_grad=False) - layer.quant_state_1 = torch.nn.Parameter(dm, requires_grad=False) + d = d.view(torch.uint8).reshape((n_blocks, -1)) + dm = dm.view(torch.uint8).reshape((n_blocks, -1)) + qs = qs.view(torch.uint8) + + parameter.data = torch.cat([d, dm, qs], dim=-1).contiguous() return @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # Compute in each diffusion iteration n_blocks = blocks.shape[0] - d, dm, qs = parent.quant_state_0, parent.quant_state_1, blocks + d, dm, qs = quick_split(blocks, [16, 16]) + d = d.view(parameter.computation_dtype).view((n_blocks, -1, 1)) + dm = dm.view(parameter.computation_dtype).view((n_blocks, -1, 1)) + qs = quick_unpack_4bits_u(qs).view((n_blocks, -1, 32)) - if d.device != qs.device: - d = d.to(device=qs.device) - - if dm.device != qs.device: - dm = dm.to(device=qs.device) - - qs = quick_unpack_4bits_u(qs).reshape((n_blocks, -1, 32)) return (d * qs - dm).reshape((n_blocks, QK_K)) @@ -867,14 +852,14 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K): return (d * q - dm).reshape((n_blocks, QK_K)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) QK_K = 256 K_SCALE_SIZE = 12 n_blocks = blocks.shape[0] d, dmin, scales, qh, qs = quick_split(blocks, [2, 2, K_SCALE_SIZE, QK_K // 8]) - d = d.view(torch.float16).to(cls.computation_dtype) - dmin = dmin.view(torch.float16).to(cls.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) + dmin = dmin.view(torch.float16).to(parameter.computation_dtype) sc, m = Q4_K.get_scale_min_pytorch(scales) d = (d * sc).reshape((n_blocks, -1, 1)) dm = (dmin * m).reshape((n_blocks, -1, 1)) @@ -909,12 +894,12 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K): return (d * q).reshape((n_blocks, QK_K)) @classmethod - def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parent) -> torch.Tensor: + def dequantize_blocks_pytorch(cls, blocks, block_size, type_size, parameter) -> torch.Tensor: # Written by ChatGPT n_blocks = blocks.shape[0] ql, qh, scales, d, = quick_split(blocks, [QK_K // 2, QK_K // 4, QK_K // 16]) - scales = scales.view(torch.int8).to(cls.computation_dtype) - d = d.view(torch.float16).to(cls.computation_dtype) + scales = scales.view(torch.int8).to(parameter.computation_dtype) + d = d.view(torch.float16).to(parameter.computation_dtype) d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) ql = (ql & 0x0F).reshape((n_blocks, -1, 32))