From d8b83a95017eddab4d4a1c008d85b3d3893eed80 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Wed, 14 Aug 2024 22:26:00 -0700 Subject: [PATCH] gguf preview --- backend/loader.py | 6 +- backend/memory_management.py | 4 +- backend/operations.py | 42 ++++++++++++- backend/operations_gguf.py | 109 ++++++++++++++++++++++++++++++++++ backend/patcher/controlnet.py | 2 +- backend/utils.py | 31 ++++++++++ modules/sd_models.py | 2 +- requirements_versions.txt | 1 + 8 files changed, 190 insertions(+), 7 deletions(-) create mode 100644 backend/operations_gguf.py diff --git a/backend/loader.py b/backend/loader.py index fa33ccbc..4aa59b42 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -117,17 +117,17 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p if unet_storage_dtype_overwrite is not None: storage_dtype = unet_storage_dtype_overwrite else: - if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4']: + if state_dict_dtype in [torch.float8_e4m3fn, torch.float8_e5m2, 'nf4', 'fp4', 'gguf']: print(f'Using Detected UNet Type: {state_dict_dtype}') storage_dtype = state_dict_dtype - if state_dict_dtype in ['nf4', 'fp4']: + if state_dict_dtype in ['nf4', 'fp4', 'gguf']: print(f'Using pre-quant state dict!') load_device = memory_management.get_torch_device() computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=guess.supported_inference_dtypes) offload_device = memory_management.unet_offload_device() - if storage_dtype in ['nf4', 'fp4']: + if storage_dtype in ['nf4', 'fp4', 'gguf']: initial_device = memory_management.unet_inital_load_device(parameters=state_dict_size, dtype=computation_dtype) with using_forge_operations(device=initial_device, dtype=computation_dtype, manual_cast_enabled=False, bnb_dtype=storage_dtype): model = model_loader(unet_config) diff --git a/backend/memory_management.py b/backend/memory_management.py index 4f64c61a..69f84de7 100644 --- a/backend/memory_management.py +++ b/backend/memory_management.py @@ -302,7 +302,9 @@ def state_dict_size(sd, exclude_device=None): def state_dict_dtype(state_dict): - for k in state_dict.keys(): + for k, v in state_dict.items(): + if hasattr(v, 'is_gguf'): + return 'gguf' if 'bitsandbytes__nf4' in k: return 'nf4' if 'bitsandbytes__fp4' in k: diff --git a/backend/operations.py b/backend/operations.py index 06a56688..636e9df4 100644 --- a/backend/operations.py +++ b/backend/operations.py @@ -330,6 +330,44 @@ except: bnb_avaliable = False +from backend.operations_gguf import functional_linear_gguf + + +class ForgeOperationsGGUF(ForgeOperations): + class Linear(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype)) + self.weight = None + self.bias = None + self.parameters_manual_cast = current_manual_cast_enabled + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + if hasattr(self, 'dummy'): + if prefix + 'weight' in state_dict: + self.weight = state_dict[prefix + 'weight'].to(device=self.dummy.device) + if prefix + 'bias' in state_dict: + self.bias = state_dict[prefix + 'bias'].to(device=self.dummy.device) + del self.dummy + else: + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + def _apply(self, fn, recurse=True): + if self.weight is not None: + self.weight = fn(self.weight) + if self.bias is not None: + self.bias = fn(self.bias) + return super()._apply(fn, recurse=recurse) + + def forward(self, x): + if self.parameters_manual_cast: + weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True) + with main_stream_worker(weight, bias, signal): + return functional_linear_gguf(x, weight, bias) + else: + return functional_linear_gguf(x, self.weight, self.bias) + + @contextlib.contextmanager def using_forge_operations(operations=None, device=None, dtype=None, manual_cast_enabled=False, bnb_dtype=None): global current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype @@ -337,7 +375,9 @@ def using_forge_operations(operations=None, device=None, dtype=None, manual_cast current_device, current_dtype, current_manual_cast_enabled, current_bnb_dtype = device, dtype, manual_cast_enabled, bnb_dtype if operations is None: - if bnb_avaliable and bnb_dtype in ['nf4', 'fp4']: + if bnb_dtype in ['gguf']: + operations = ForgeOperationsGGUF + elif bnb_avaliable and bnb_dtype in ['nf4', 'fp4']: operations = ForgeOperationsBNB4bits else: operations = ForgeOperations diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py new file mode 100644 index 00000000..7a01ee86 --- /dev/null +++ b/backend/operations_gguf.py @@ -0,0 +1,109 @@ +import gguf +import torch + + +def functional_linear_gguf(x, weight, bias=None): + target_dtype = x.dtype + weight = dequantize_tensor(weight, target_dtype) + bias = dequantize_tensor(bias, target_dtype) + return torch.nn.functional.linear(x, weight, bias) + + +def dequantize_tensor(tensor, dtype=torch.float16): + # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) + + if tensor is None: + return None + + data = torch.tensor(tensor.data) + qtype = tensor.gguf_type + oshape = tensor.gguf_real_shape + + if qtype == gguf.GGMLQuantizationType.F32: + return data.to(dtype) + elif qtype == gguf.GGMLQuantizationType.F16: + return data.to(dtype) + elif qtype in dequantize_functions: + # this is the main pytorch op + return dequantize(data, qtype, oshape).to(dtype) + else: + # this is incredibly slow + new = gguf.quants.dequantize(data.cpu().numpy(), qtype) + return torch.from_numpy(new).to(data.device, dtype=dtype) + + +def dequantize(data, qtype, oshape): + # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) + + """ + Dequantize tensor back to usable shape/dtype + """ + block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] + dequantize_blocks = dequantize_functions[qtype] + + rows = data.reshape( + (-1, data.shape[-1]) + ).view(torch.uint8) + + n_blocks = rows.numel() // type_size + blocks = rows.reshape((n_blocks, type_size)) + blocks = dequantize_blocks(blocks, block_size, type_size) + return blocks.reshape(oshape) + + +def to_uint32(x): + # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) + + # no uint32 :( + x = x.view(torch.uint8).to(torch.int32) + return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1) + + +def dequantize_blocks_Q8_0(blocks, block_size, type_size): + # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) + + d = blocks[:, :2].view(torch.float16) + x = blocks[:, 2:].view(torch.int8).to(torch.float16) + return (x * d) + + +def dequantize_blocks_Q5_0(blocks, block_size, type_size): + # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) + + n_blocks = blocks.shape[0] + + d = blocks[:, :2] + qh = blocks[:, 2:6] + qs = blocks[:, 6:] + + d = d.view(torch.float16).to(torch.float32) + qh = to_uint32(qh) + + qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) + + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape(n_blocks, -1) + + qs = (ql | (qh << 4)).to(torch.int8) - 16 + return (d * qs) + + +def dequantize_blocks_Q4_0(blocks, block_size, type_size): + # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) + + n_blocks = blocks.shape[0] + + d = blocks[:, :2].view(torch.float16) + qs = blocks[:, 2:] + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 + return (d * qs) + + +dequantize_functions = { + gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, + gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0, + gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0, +} diff --git a/backend/patcher/controlnet.py b/backend/patcher/controlnet.py index b21900b2..b0372d86 100644 --- a/backend/patcher/controlnet.py +++ b/backend/patcher/controlnet.py @@ -431,7 +431,7 @@ class ControlLora(ControlNet): dtype = model.storage_dtype - if dtype in ['nf4', 'fp4']: + if dtype in ['nf4', 'fp4', 'gguf']: dtype = torch.float16 controlnet_config["dtype"] = dtype diff --git a/backend/utils.py b/backend/utils.py index 44f9eab0..b1b5ee3f 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -1,3 +1,4 @@ +import gguf import torch import os import json @@ -5,6 +6,31 @@ import safetensors.torch import backend.misc.checkpoint_pickle +class ParameterGGUF(torch.nn.Parameter): + def __init__(self, tensor=None, requires_grad=False, no_init=False): + super().__init__() + self.is_gguf = True + + if no_init: + return + + self.gguf_type = tensor.tensor_type + self.gguf_real_shape = torch.Size(reversed(list(tensor.shape))) + + @property + def shape(self): + return self.gguf_real_shape + + def __new__(cls, tensor=None, requires_grad=False, no_init=False): + return super().__new__(cls, torch.tensor(tensor.data), requires_grad=requires_grad) + + def to(self, *args, **kwargs): + new = ParameterGGUF(self.data.to(*args, **kwargs), no_init=True) + new.gguf_type = self.gguf_type + new.gguf_real_shape = self.gguf_real_shape + return new + + def read_arbitrary_config(directory): config_path = os.path.join(directory, 'config.json') @@ -22,6 +48,11 @@ def load_torch_file(ckpt, safe_load=False, device=None): device = torch.device("cpu") if ckpt.lower().endswith(".safetensors"): sd = safetensors.torch.load_file(ckpt, device=device.type) + elif ckpt.lower().endswith(".gguf"): + reader = gguf.GGUFReader(ckpt) + sd = {} + for tensor in reader.tensors: + sd[str(tensor.name)] = ParameterGGUF(tensor) else: if safe_load: if not 'weights_only' in torch.load.__code__.co_varnames: diff --git a/modules/sd_models.py b/modules/sd_models.py index 16c4d459..3756dadc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -173,7 +173,7 @@ def list_models(): else: model_url = "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticVisionV51_v51VAE.safetensors" - model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="realisticVisionV51_v51VAE.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"]) + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors", ".gguf"], download_name="realisticVisionV51_v51VAE.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"]) if os.path.exists(cmd_ckpt): checkpoint_info = CheckpointInfo(cmd_ckpt) diff --git a/requirements_versions.txt b/requirements_versions.txt index 7bb2a7bc..71af0bdf 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -37,3 +37,4 @@ basicsr==1.4.2 diffusers==0.29.2 gradio_rangeslider==0.0.6 tqdm==4.66.1 +gguf==0.9.1