diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py
index 37f8b83c..5a01d1ca 100644
--- a/backend/operations_gguf.py
+++ b/backend/operations_gguf.py
@@ -2,34 +2,27 @@
 import gguf
 import torch
 
-quants_mapping = {
-    gguf.GGMLQuantizationType.Q4_0: gguf.Q4_0,
-    gguf.GGMLQuantizationType.Q5_0: gguf.Q5_0,
-    gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
-}
+# def functional_quantize_gguf(weight):
+#     gguf_cls = weight.gguf_cls
+#     gguf_cls.en
 
 
 def functional_linear_gguf(x, weight, bias=None):
     target_dtype = x.dtype
-    weight = dequantize_tensor(weight, target_dtype)
-    bias = dequantize_tensor(bias, target_dtype)
+    weight = dequantize_tensor(weight).to(target_dtype)
+    bias = dequantize_tensor(bias).to(target_dtype)
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def dequantize_tensor(tensor, target_dtype=torch.float16):
+def dequantize_tensor(tensor):
     if tensor is None:
         return None
 
     data = torch.tensor(tensor.data)
-    gguf_type = tensor.gguf_type
+    gguf_cls = tensor.gguf_cls
    gguf_real_shape = tensor.gguf_real_shape
 
-    if gguf_type in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16]:
-        return data.to(target_dtype)
+    if gguf_cls is None:
+        return data
 
-    if gguf_type not in quants_mapping:
-        raise NotImplementedError(f'Quant type {gguf_type} not implemented!')
-
-    quant_cls = quants_mapping.get(gguf_type)
-
-    return quant_cls.dequantize_pytorch(data, gguf_real_shape).to(target_dtype)
+    return gguf_cls.dequantize_pytorch(data, gguf_real_shape)
diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index 361130e1..fdf14639 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -4,10 +4,12 @@
 # are from Forge, implemented from scratch (after forge-v1.0.1), and may have
 # certain level of differences.
 
+import time
 import torch
 import copy
 import inspect
 
+from tqdm import tqdm
 from backend import memory_management, utils, operations
 from backend.patcher.lora import merge_lora_to_model_weight
@@ -237,6 +239,8 @@ class ModelPatcher:
         return sd
 
     def forge_patch_model(self, target_device=None):
+        execution_start_time = time.perf_counter()
+
         for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
@@ -245,13 +249,16 @@ class ModelPatcher:
             utils.set_attr_raw(self.model, k, item)
 
-        for key, current_patches in self.patches.items():
+        for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs to Diffusion Model') if len(self.patches) > 0 else self.patches):
             try:
                 weight = utils.get_attr(self.model, key)
                 assert isinstance(weight, torch.nn.Parameter)
             except:
                 raise ValueError(f"Wrong LoRA Key: {key}")
 
+            weight_original_device = weight.device
+            lora_computation_device = weight.device
+
             if key not in self.backup:
                 self.backup[key] = weight.to(device=self.offload_device)
@@ -262,8 +269,6 @@ class ModelPatcher:
                 assert weight.module is not None, 'BNB bad weight without parent layer!'
                 bnb_layer = weight.module
                 if weight.bnb_quantized:
-                    weight_original_device = weight.device
-
                     if target_device is not None:
                         assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                         weight = weight.to(target_device)
@@ -272,35 +277,56 @@ class ModelPatcher:
                     from backend.operations_bnb import functional_dequantize_4bit
                     weight = functional_dequantize_4bit(weight)
-
-                    if target_device is None:
-                        weight = weight.to(device=weight_original_device)
                 else:
                     weight = weight.data
 
+            if target_device is None:
+                weight = weight.to(device=lora_computation_device, non_blocking=memory_management.device_supports_non_blocking(lora_computation_device))
+            else:
+                weight = weight.to(device=target_device, non_blocking=memory_management.device_supports_non_blocking(target_device))
+
+            gguf_cls, gguf_type, gguf_real_shape = None, None, None
+
             if hasattr(weight, 'is_gguf'):
-                raise NotImplementedError('LoRAs for GGUF model are under construction!')
+                from backend.operations_gguf import dequantize_tensor
+                gguf_cls = weight.gguf_cls
+                gguf_type = weight.gguf_type
+                gguf_real_shape = weight.gguf_real_shape
+                weight = dequantize_tensor(weight)
 
             weight_original_dtype = weight.dtype
 
-            to_args = dict(dtype=torch.float32)
+            weight = weight.to(dtype=torch.float32, non_blocking=memory_management.device_supports_non_blocking(weight.device))
+            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
 
-            if target_device is not None:
-                to_args['device'] = target_device
-                to_args['non_blocking'] = memory_management.device_supports_non_blocking(target_device)
-
-            weight = weight.to(**to_args)
-            out_weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+            if target_device is None:
+                weight = weight.to(device=weight_original_device, non_blocking=memory_management.device_supports_non_blocking(weight_original_device))
 
             if bnb_layer is not None:
-                bnb_layer.reload_weight(out_weight)
+                bnb_layer.reload_weight(weight)
                 continue
 
-            utils.set_attr_raw(self.model, key, torch.nn.Parameter(out_weight, requires_grad=False))
+            if gguf_cls is not None:
+                from backend.utils import ParameterGGUF
+                weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
+                utils.set_attr_raw(self.model, key, ParameterGGUF.make(
+                    data=weight,
+                    gguf_type=gguf_type,
+                    gguf_cls=gguf_cls,
+                    gguf_real_shape=gguf_real_shape
+                ))
+                continue
+
+            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
 
         if target_device is not None:
             self.model.to(target_device)
             self.current_device = target_device
 
+        moving_time = time.perf_counter() - execution_start_time
+
+        if moving_time > 0.1:
+            print(f'LoRA patching has taken {moving_time:.2f} seconds')
+
         return self.model
 
     def forge_unpatch_model(self, target_device=None):
diff --git a/backend/utils.py b/backend/utils.py
index b1b5ee3f..7860eb08 100644
--- a/backend/utils.py
+++ b/backend/utils.py
@@ -6,6 +6,13 @@ import safetensors.torch
 import backend.misc.checkpoint_pickle
 
 
+quants_mapping = {
+    gguf.GGMLQuantizationType.Q4_0: gguf.Q4_0,
+    gguf.GGMLQuantizationType.Q5_0: gguf.Q5_0,
+    gguf.GGMLQuantizationType.Q8_0: gguf.Q8_0,
+}
+
+
 class ParameterGGUF(torch.nn.Parameter):
     def __init__(self, tensor=None, requires_grad=False, no_init=False):
         super().__init__()
@@ -16,6 +23,7 @@ class ParameterGGUF(torch.nn.Parameter):
 
         self.gguf_type = tensor.tensor_type
         self.gguf_real_shape = torch.Size(reversed(list(tensor.shape)))
+        self.gguf_cls = quants_mapping.get(self.gguf_type, None)
 
     @property
     def shape(self):
@@ -28,6 +36,15 @@ class ParameterGGUF(torch.nn.Parameter):
         new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True)
         new.gguf_type = self.gguf_type
         new.gguf_real_shape = self.gguf_real_shape
+        new.gguf_cls = self.gguf_cls
+        return new
+
+    @classmethod
+    def make(cls, data, gguf_type, gguf_cls, gguf_real_shape):
+        new = ParameterGGUF(data, no_init=True)
+        new.gguf_type = gguf_type
+        new.gguf_real_shape = gguf_real_shape
+        new.gguf_cls = gguf_cls
         return new
 
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index f30fc31e..6614bddc 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -324,7 +324,7 @@ class UiSettings:
         )
 
         def button_set_checkpoint_change(value, dummy):
-            return value, opts.dumpjson()
+            return value.split(' [')[0], opts.dumpjson()
 
         button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
         button_set_checkpoint.click(
diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py
index 08ce9437..bc434438 100644
--- a/packages_3rdparty/gguf/quants.py
+++ b/packages_3rdparty/gguf/quants.py
@@ -125,8 +125,17 @@ class __Quant(ABC):
         cls.grid = grid.reshape((1, 1, *cls.grid_shape))
 
     @classmethod
-    def quantize_pytorch(cls, data: torch.Tensor) -> torch.Tensor:
-        return cls.quantize_blocks_pytorch(data)
+    def quantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
+        # Copyright Forge 2024, AGPL V3 + CC-BY SA
+
+        original_shape = [x for x in original_shape]
+        original_shape[-1] = -1
+        original_shape = tuple(original_shape)
+
+        block_size, type_size = GGML_QUANT_SIZES[cls.qtype]
+        blocks = data.reshape(-1, block_size)
+        blocks = cls.quantize_blocks_pytorch(blocks, block_size, type_size)
+        return blocks.reshape(original_shape)
 
     @classmethod
     def dequantize_pytorch(cls, data: torch.Tensor, original_shape) -> torch.Tensor:
@@ -145,7 +154,7 @@ class __Quant(ABC):
 
     @classmethod
     @abstractmethod
-    def quantize_blocks_pytorch(cls, blocks) -> torch.Tensor:
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
         raise NotImplementedError
 
     @classmethod
@@ -287,6 +296,27 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
         return d * qs
 
+    @classmethod
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        # Copyright Forge 2024, AGPL V3 + CC-BY SA
+
+        n_blocks = blocks.shape[0]
+
+        imax = torch.abs(blocks).argmax(dim=-1, keepdim=True)
+        max_vals = torch.gather(blocks, -1, imax)
+
+        d = max_vals / -8
+        id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1.0 / d)
+
+        qs = torch.trunc((blocks * id) + 8.5).clip(0, 15).to(torch.uint8)
+
+        qs = qs.reshape((n_blocks, 2, block_size // 2))
+        qs = qs[:, 0, :] | (qs[:, 1, :] << 4)
+
+        d = d.to(torch.float16).view(torch.uint8)
+
+        return torch.cat([d, qs], dim=-1)
+
 
 class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1):
     @classmethod
@@ -392,6 +422,42 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         qs = (ql | (qh << 4)).to(torch.int8) - 16
         return d * qs
 
+    @classmethod
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        # Copyright Forge 2024, AGPL V3 + CC-BY SA
+
+        n_blocks = blocks.shape[0]
+
+        imax = torch.abs(blocks).argmax(dim=-1, keepdim=True)
+        max_val = torch.gather(blocks, dim=-1, index=imax)
+
+        d = max_val / -16
+        id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1.0 / d)
+
+        q = torch.trunc((blocks.float() * id.float()) + 16.5).clamp(0, 31).to(torch.uint8)
+
+        qs = q.view(n_blocks, 2, block_size // 2)
+        qs = (qs[..., 0, :] & 0x0F) | (qs[..., 1, :] << 4)
+
+        qh = q.view(n_blocks, 32)
+        qh_packed = torch.zeros((n_blocks, 4), dtype=torch.uint8, device=qh.device)
+
+        for i in range(4):
+            qh_packed[:, i] = (
+                (qh[:, i * 8 + 0] >> 4) |
+                (qh[:, i * 8 + 1] >> 3 & 0x02) |
+                (qh[:, i * 8 + 2] >> 2 & 0x04) |
+                (qh[:, i * 8 + 3] >> 1 & 0x08) |
+                (qh[:, i * 8 + 4] << 0 & 0x10) |
+                (qh[:, i * 8 + 5] << 1 & 0x20) |
+                (qh[:, i * 8 + 6] << 2 & 0x40) |
+                (qh[:, i * 8 + 7] << 3 & 0x80)
+            )
+
+        d = d.to(torch.float16).view(torch.uint8)
+
+        return torch.cat([d, qh_packed, qs], dim=-1)
+
 
 class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1):
     @classmethod
@@ -469,6 +535,16 @@ class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
         x = blocks[:, 2:].view(torch.int8).to(torch.float16)
         return x * d
 
+    @classmethod
+    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
+        # Copyright Forge 2024, AGPL V3 + CC-BY SA
+        d = torch.abs(blocks).max(dim=1, keepdim=True).values / 127
+        ids = torch.where(d == 0, torch.zeros_like(d), 1 / d)
+        qs = torch.round(blocks * ids)
+        d = d.to(torch.float16).view(torch.uint8)
+        qs = qs.to(torch.int8).view(torch.uint8)
+        return torch.cat([d, qs], dim=1)
+
 
 class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K):
     @classmethod
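
Note (not part of the patch): forge_patch_model now dequantizes a GGUF weight with dequantize_tensor, merges the LoRA in float32, then re-quantizes it via gguf_cls.quantize_pytorch and wraps it with ParameterGGUF.make, so the patched model stays in its original GGUF storage format. The sketch below is a minimal round-trip check of that quantize/dequantize pair for Q8_0, assuming Forge's bundled packages_3rdparty gguf fork is importable as `gguf` (as in backend/operations_gguf.py); the tensor names, shapes, and values are illustrative only, not taken from the diff.

# Minimal sketch: round-trip a fake weight through the new Q8_0 quantize path,
# mirroring what forge_patch_model does around merge_lora_to_model_weight().
# Assumes the bundled gguf fork is on the import path as `gguf`.
import torch
import gguf

real_shape = torch.Size([8, 64])                       # last dim must be a multiple of the 32-wide Q8_0 block
weight = torch.randn(real_shape, dtype=torch.float32)  # stand-in for a dequantized, LoRA-merged weight

packed = gguf.Q8_0.quantize_pytorch(weight, real_shape)      # raw GGUF bytes: per block, fp16 scale (2 bytes) + 32 int8 values
restored = gguf.Q8_0.dequantize_pytorch(packed, real_shape)  # same path used by backend/operations_gguf.dequantize_tensor

print(packed.shape, packed.dtype)                # expect torch.Size([8, 68]), torch.uint8
print((weight - restored.float()).abs().max())   # error bounded by roughly half a quantization step per block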