From d38e560e42b20f1f34b985187adbd1cde58bb15a Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Mon, 19 Aug 2024 04:31:00 -0700
Subject: [PATCH] Implement some rethinking about LoRA system
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Add an option that allows users to run the UNet in fp8/gguf while keeping LoRAs in fp16.
2. FP16 LoRAs no longer need any patching; other LoRAs are re-patched only when the LoRA weights change.
3. FP8 UNet + fp16 LoRA is now available in Forge (and, for now, essentially only in Forge). This also solves some “LoRA too subtle” problems.
4. Significantly speed up all gguf models (in Async mode) by using an independent thread (CUDA stream) so that computation and dequantization happen at the same time, even when the low-bit weights are already on the GPU.
5. Treat the “online LoRA” as a module similar to ControlLoRA, so it is moved to the GPU together with the model when sampling, achieving a significant speedup while keeping low-VRAM management intact.
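
To illustrate the idea behind points 1-3 and 5: the base weight stays in its low-bit
storage format, and the fp16 LoRA delta is merged into the compute-dtype weight at
forward time instead of being baked into the stored weights. Below is a minimal,
simplified sketch of that pattern; OnlineLoraLinear, lora_down and lora_up are
illustrative names only, and the actual implementation in this patch is
get_weight_and_bias / merge_lora_to_weight together with the forge_online_loras
attribute.

    import torch

    class OnlineLoraLinear(torch.nn.Module):
        def __init__(self, stored_weight, bias, lora_down, lora_up, alpha=1.0):
            super().__init__()
            self.stored_weight = stored_weight  # fp8 (or dequantizable gguf) storage
            self.bias = bias
            self.lora_down = lora_down          # fp16 LoRA factors, never quantized
            self.lora_up = lora_up
            self.alpha = alpha

        def forward(self, x):
            # Cast (or dequantize) the base weight to the compute dtype ...
            weight = self.stored_weight.to(x.dtype)
            # ... then add the fp16 LoRA delta on the fly, so the LoRA itself
            # never has to pass through the low-bit storage format.
            weight = weight + self.alpha * (self.lora_up @ self.lora_down).to(x.dtype)
            return torch.nn.functional.linear(x, weight, self.bias)

Because the LoRA patches are plain tensors attached to the module, they can be moved
to the GPU together with the model when sampling, which is what point 5 relies on.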
---
 backend/memory_management.py          |   1 +
 backend/operations.py                 |  91 +++++++++++++------
 backend/operations_gguf.py            |   7 --
 backend/patcher/base.py               |   9 +-
 backend/patcher/lora.py               | 113 +++++++++++++-----------
 backend/sampling/sampling_function.py |  11 +++
 backend/utils.py                      |  32 +++++++-
 launch.py                             |   3 +
 modules/processing.py                 |   1 +
 modules_forge/main_entry.py           |  20 +++--
 packages_3rdparty/gguf/quants.py      |  71 +---------------
 11 files changed, 200 insertions(+), 159 deletions(-)

diff --git a/backend/memory_management.py b/backend/memory_management.py
index c6ff6b95..de56e306 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -378,6 +378,7 @@ class LoadedModel:
 
         try:
             self.real_model = self.model.forge_patch_model(patch_model_to)
+            self.model.current_device = self.model.load_device
         except Exception as e:
             self.model.forge_unpatch_model(self.model.offload_device)
             self.model_unload()
diff --git a/backend/operations.py b/backend/operations.py
index 1d79a4c3..da6903cb 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -5,12 +5,53 @@ import torch
 import contextlib
 
 from backend import stream, memory_management, utils
+from backend.patcher.lora import merge_lora_to_weight
 
 
 stash = {}
 
 
-def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
+def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
+    patches = getattr(layer, 'forge_online_loras', None)
+    weight_patches, bias_patches = None, None
+
+    if patches is not None:
+        weight_patches = patches.get('weight', None)
+    if patches is not None:
+        bias_patches = patches.get('bias', None)
+
+    weight = None
+    if layer.weight is not None:
+        weight = layer.weight
+        if weight_fn is not None:
+            if weight_args is not None:
+                fn_device = weight_args.get('device', None)
+                if fn_device is not None:
+                    weight = weight.to(device=fn_device)
+            weight = weight_fn(weight)
+        if weight_args is not None:
+            weight = weight.to(**weight_args)
+        if weight_patches is not None:
+            weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
+
+    bias = None
+    if layer.bias is not None:
+        bias = layer.bias
+        if bias_fn is not None:
+            if bias_args is not None:
+                fn_device = bias_args.get('device', None)
+                if fn_device is not None:
+                    bias = bias.to(device=fn_device)
+            bias = bias_fn(bias)
+        if bias_args is not None:
+            bias = bias.to(**bias_args)
+        if bias_patches is not None:
+            bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)
+
+    return weight, bias
+
+
+def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
     weight, bias, signal = None, None, None
     non_blocking = True
@@ -32,16 +73,10 @@ def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False
 
     if stream.should_use_stream():
         with stream.stream_context()(stream.mover_stream):
-            if layer.weight is not None:
-                weight = layer.weight.to(**weight_args)
-            if layer.bias is not None:
-                bias = layer.bias.to(**bias_args)
+            weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
             signal = stream.mover_stream.record_event()
     else:
-        if layer.weight is not None:
-            weight = layer.weight.to(**weight_args)
-        if layer.bias is not None:
-            bias = layer.bias.to(**bias_args)
+        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
 
     return weight, bias, signal
 
@@ -109,7 +144,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.linear(x, weight, bias)
             else:
-                return torch.nn.functional.linear(x, self.weight, self.bias)
+                weight, bias = get_weight_and_bias(self)
+                return torch.nn.functional.linear(x, weight, bias)
 
     class Conv2d(torch.nn.Conv2d):
 
@@ -128,7 +164,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return self._conv_forward(x, weight, bias)
             else:
-                return super().forward(x)
+                weight, bias = get_weight_and_bias(self)
+                return super()._conv_forward(x, weight, bias)
 
     class Conv3d(torch.nn.Conv3d):
 
@@ -147,7 +184,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return self._conv_forward(x, weight, bias)
             else:
-                return super().forward(x)
+                weight, bias = get_weight_and_bias(self)
+                return super()._conv_forward(x, weight, bias)
 
     class Conv1d(torch.nn.Conv1d):
 
@@ -166,7 +204,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return self._conv_forward(x, weight, bias)
             else:
-                return super().forward(x)
+                weight, bias = get_weight_and_bias(self)
+                return super()._conv_forward(x, weight, bias)
 
     class ConvTranspose2d(torch.nn.ConvTranspose2d):
 
@@ -188,7 +227,10 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
             else:
-                return super().forward(x, output_size)
+                weight, bias = get_weight_and_bias(self)
+                num_spatial_dims = 2
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+                return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
 
     class ConvTranspose1d(torch.nn.ConvTranspose1d):
 
@@ -210,7 +252,10 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
             else:
-                return super().forward(x, output_size)
+                weight, bias = get_weight_and_bias(self)
+                num_spatial_dims = 1
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+                return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
 
     class ConvTranspose3d(torch.nn.ConvTranspose3d):
 
@@ -232,7 +277,10 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
             else:
-                return super().forward(x, output_size)
+                weight, bias = get_weight_and_bias(self)
+                num_spatial_dims = 3
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+                return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
 
     class GroupNorm(torch.nn.GroupNorm):
 
@@ -328,7 +376,7 @@ except:
     bnb_avaliable = False
 
 
-from backend.operations_gguf import functional_linear_gguf
+from backend.operations_gguf import dequantize_tensor
 
 
 class ForgeOperationsGGUF(ForgeOperations):
 
@@ -361,12 +409,9 @@ class ForgeOperationsGGUF(ForgeOperations):
             return self
 
         def forward(self, x):
-            if self.parameters_manual_cast:
-                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True)
-                with main_stream_worker(weight, bias, signal):
-                    return functional_linear_gguf(x, weight, bias)
-            else:
-                return functional_linear_gguf(x, self.weight, self.bias)
+            weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=dequantize_tensor)
+            with main_stream_worker(weight, bias, signal):
+                return torch.nn.functional.linear(x, weight, bias)
 
 
 @contextlib.contextmanager
diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py
index 495e6396..80b77382 100644
--- a/backend/operations_gguf.py
+++ b/backend/operations_gguf.py
@@ -58,13 +58,6 @@ class ParameterGGUF(torch.nn.Parameter):
         return new
 
 
-def functional_linear_gguf(x, weight, bias=None):
-    target_dtype = x.dtype
-    weight = dequantize_tensor(weight).to(target_dtype)
-    bias = dequantize_tensor(bias).to(target_dtype)
-    return torch.nn.functional.linear(x, weight, bias)
-
-
 def dequantize_tensor(tensor):
     if tensor is None:
         return None
diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index f60e60b7..39c28c60 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -4,14 +4,12 @@
 # are from Forge, implemented from scratch (after forge-v1.0.1), and may have
 # certain level of differences.
-import time -import torch + import copy import inspect -from tqdm import tqdm -from backend import memory_management, utils, operations -from backend.patcher.lora import merge_lora_to_model_weight, LoraLoader +from backend import memory_management, utils +from backend.patcher.lora import LoraLoader def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): @@ -229,7 +227,6 @@ class ModelPatcher: if target_device is not None: self.model.to(target_device) - self.current_device = target_device return self.model diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py index d19556e2..01eb6930 100644 --- a/backend/patcher/lora.py +++ b/backend/patcher/lora.py @@ -5,7 +5,8 @@ import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui from tqdm import tqdm -from backend import memory_management, utils, operations +from backend import memory_management, utils +from backend.args import dynamic_args class ForgeLoraCollection: @@ -39,8 +40,10 @@ def model_lora_keys_unet(model, key_map={}): return get_function('model_lora_keys_unet')(model, key_map) -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength): - dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32) +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype): + # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33 + + dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype) lora_diff *= alpha weight_calc = weight + lora_diff.type(weight.dtype) weight_norm = ( @@ -60,7 +63,12 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength): return weight -def merge_lora_to_model_weight(patches, weight, key): +def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32): + # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446 + + weight_original_dtype = weight.dtype + weight = weight.to(dtype=computation_dtype) + for p in patches: strength = p[0] v = p[1] @@ -79,7 +87,7 @@ def merge_lora_to_model_weight(patches, weight, key): weight *= strength_model if isinstance(v, list): - v = (merge_lora_to_model_weight(v[1:], v[0].clone(), key),) + v = (merge_lora_to_weight(v[1:], v[0].clone(), key),) patch_type = '' @@ -107,8 +115,8 @@ def merge_lora_to_model_weight(patches, weight, key): else: weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype) elif patch_type == "lora": - mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32) - mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32) + mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype) + mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype) dora_scale = v[4] if v[2] is not None: alpha = v[2] / mat2.shape[0] @@ -116,13 +124,13 @@ def merge_lora_to_model_weight(patches, weight, key): alpha = 1.0 if v[3] is not None: - mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32) + mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype) final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 
1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) if dora_scale is not None: - weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype)) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: @@ -141,23 +149,23 @@ def merge_lora_to_model_weight(patches, weight, key): if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32), - memory_management.cast_to_device(w1_b, weight.device, torch.float32)) + w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, computation_dtype), + memory_management.cast_to_device(w1_b, weight.device, computation_dtype)) else: - w1 = memory_management.cast_to_device(w1, weight.device, torch.float32) + w1 = memory_management.cast_to_device(w1, weight.device, computation_dtype) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32), - memory_management.cast_to_device(w2_b, weight.device, torch.float32)) + w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, computation_dtype), + memory_management.cast_to_device(w2_b, weight.device, computation_dtype)) else: w2 = torch.einsum('i j k l, j r, i p -> p r k l', - memory_management.cast_to_device(t2, weight.device, torch.float32), - memory_management.cast_to_device(w2_b, weight.device, torch.float32), - memory_management.cast_to_device(w2_a, weight.device, torch.float32)) + memory_management.cast_to_device(t2, weight.device, computation_dtype), + memory_management.cast_to_device(w2_b, weight.device, computation_dtype), + memory_management.cast_to_device(w2_a, weight.device, computation_dtype)) else: - w2 = memory_management.cast_to_device(w2, weight.device, torch.float32) + w2 = memory_management.cast_to_device(w2, weight.device, computation_dtype) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -169,7 +177,7 @@ def merge_lora_to_model_weight(patches, weight, key): try: lora_diff = torch.kron(w1, w2).reshape(weight.shape) if dora_scale is not None: - weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype)) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: @@ -190,24 +198,24 @@ def merge_lora_to_model_weight(patches, weight, key): t1 = v[5] t2 = v[6] m1 = torch.einsum('i j k l, j r, i p -> p r k l', - memory_management.cast_to_device(t1, weight.device, torch.float32), - memory_management.cast_to_device(w1b, weight.device, torch.float32), - memory_management.cast_to_device(w1a, weight.device, torch.float32)) + memory_management.cast_to_device(t1, weight.device, computation_dtype), + memory_management.cast_to_device(w1b, weight.device, computation_dtype), + memory_management.cast_to_device(w1a, weight.device, computation_dtype)) m2 = torch.einsum('i j k l, j r, i p -> p r k l', - memory_management.cast_to_device(t2, weight.device, torch.float32), - memory_management.cast_to_device(w2b, weight.device, torch.float32), - memory_management.cast_to_device(w2a, weight.device, torch.float32)) + memory_management.cast_to_device(t2, weight.device, computation_dtype), + 
memory_management.cast_to_device(w2b, weight.device, computation_dtype), + memory_management.cast_to_device(w2a, weight.device, computation_dtype)) else: - m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32), - memory_management.cast_to_device(w1b, weight.device, torch.float32)) - m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32), - memory_management.cast_to_device(w2b, weight.device, torch.float32)) + m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, computation_dtype), + memory_management.cast_to_device(w1b, weight.device, computation_dtype)) + m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, computation_dtype), + memory_management.cast_to_device(w2b, weight.device, computation_dtype)) try: lora_diff = (m1 * m2).reshape(weight.shape) if dora_scale is not None: - weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype)) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: @@ -221,15 +229,15 @@ def merge_lora_to_model_weight(patches, weight, key): dora_scale = v[5] - a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) - a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) - b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) - b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) + a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype) + a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype) + b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype) + b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype) try: lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape) if dora_scale is not None: - weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength)) + weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype)) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: @@ -243,14 +251,19 @@ def merge_lora_to_model_weight(patches, weight, key): if old_weight is not None: weight = old_weight + weight = weight.to(dtype=weight_original_dtype) return weight +from backend import operations + + class LoraLoader: def __init__(self, model): self.model = model self.patches = {} self.backup = {} + self.online_backup = [] self.dirty = False def clear_patches(self): @@ -277,7 +290,7 @@ class LoraLoader: if key in model_sd: p.add(k) current_patches = self.patches.get(key, []) - current_patches.append((strength_patch, patches[k], strength_model, offset, function)) + current_patches.append([strength_patch, patches[k], strength_model, offset, function]) self.patches[key] = current_patches self.dirty = True @@ -293,10 +306,12 @@ class LoraLoader: # Restore - for k, w in self.backup.items(): - if target_device is not None: - w = w.to(device=target_device) + for m in set(self.online_backup): + del m.forge_online_loras + self.online_backup = [] + + for k, w in self.backup.items(): if not isinstance(w, torch.nn.Parameter): # In very few 
cases w = torch.nn.Parameter(w, requires_grad=False) @@ -305,11 +320,13 @@ class LoraLoader: self.backup = {} + online_mode = dynamic_args.get('online_lora', False) + # Patch for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches): try: - parent_layer, weight = utils.get_attr_with_parent(self.model, key) + parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key) assert isinstance(weight, torch.nn.Parameter) except: raise ValueError(f"Wrong LoRA Key: {key}") @@ -317,6 +334,14 @@ class LoraLoader: if key not in self.backup: self.backup[key] = weight.to(device=offload_device) + if online_mode: + if not hasattr(parent_layer, 'forge_online_loras'): + parent_layer.forge_online_loras = {} + + parent_layer.forge_online_loras[child_key] = current_patches + self.online_backup.append(parent_layer) + continue + bnb_layer = None if operations.bnb_avaliable: @@ -357,17 +382,13 @@ class LoraLoader: gguf_real_shape = weight.gguf_real_shape weight = dequantize_tensor(weight) - weight_original_dtype = weight.dtype - try: - weight = weight.to(dtype=torch.float32) - weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype) + weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32) except: print('Patching LoRA weights failed. Retrying by offloading models.') self.model.to(device=offload_device) memory_management.soft_empty_cache() - weight = weight.to(dtype=torch.float32) - weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype) + weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32) if bnb_layer is not None: bnb_layer.reload_weight(weight) diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py index c51ee742..a9fbec0c 100644 --- a/backend/sampling/sampling_function.py +++ b/backend/sampling/sampling_function.py @@ -10,6 +10,8 @@ import collections from backend import memory_management from backend.sampling.condition import Condition, compile_conditions, compile_weighted_conditions from backend.operations import cleanup_cache +from backend.args import dynamic_args +from backend import utils def get_area_and_mult(conds, x_in, timestep_in): @@ -353,10 +355,19 @@ def sampling_prepare(unet, x): additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype()) additional_model_patchers += unet.controlnet_linked_list.get_models() + if dynamic_args.get('online_lora', False): + lora_memory = utils.nested_compute_size(unet.lora_loader.patches) + additional_inference_memory += lora_memory + memory_management.load_models_gpu( models=[unet] + additional_model_patchers, memory_required=unet_inference_memory + additional_inference_memory) + if dynamic_args.get('online_lora', False): + utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device) + + unet.lora_loader.patches = {} + real_model = unet.model percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p) diff --git a/backend/utils.py b/backend/utils.py index 20f53ff7..2d15a76a 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -79,10 +79,11 @@ def get_attr(obj, attr): def get_attr_with_parent(obj, attr): attrs = attr.split(".") parent = obj + name = None for name in attrs: parent = obj obj = getattr(obj, name) - return parent, obj + return parent, name, obj def 
calculate_parameters(sd, prefix=""): @@ -108,3 +109,32 @@ def fp16_fix(x): if x.dtype in [torch.float16]: return x.clip(-32768.0, 32768.0) return x + + +def nested_compute_size(obj): + module_mem = 0 + + if isinstance(obj, dict): + for key in obj: + module_mem += nested_compute_size(obj[key]) + elif isinstance(obj, list) or isinstance(obj, tuple): + for i in range(len(obj)): + module_mem += nested_compute_size(obj[i]) + elif isinstance(obj, torch.Tensor): + module_mem += obj.nelement() * obj.element_size() + + return module_mem + + +def nested_move_to_device(obj, device): + if isinstance(obj, dict): + for key in obj: + obj[key] = nested_move_to_device(obj[key], device) + elif isinstance(obj, list): + for i in range(len(obj)): + obj[i] = nested_move_to_device(obj[i], device) + elif isinstance(obj, tuple): + obj = tuple(nested_move_to_device(i, device) for i in obj) + elif isinstance(obj, torch.Tensor): + return obj.to(device) + return obj diff --git a/launch.py b/launch.py index 10aa5463..c0568c7b 100644 --- a/launch.py +++ b/launch.py @@ -1,3 +1,6 @@ +# import faulthandler +# faulthandler.enable() + from modules import launch_utils args = launch_utils.args diff --git a/modules/processing.py b/modules/processing.py index bcb41827..f1de9b6d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -799,6 +799,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: memory_management.unload_all_models() if need_global_unload: + p.sd_model.forge_objects.unet.lora_loader.dirty = True p.clear_prompt_cache() need_global_unload = False diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py index 04264e1f..2f52d287 100644 --- a/modules_forge/main_entry.py +++ b/modules_forge/main_entry.py @@ -5,6 +5,7 @@ import gradio as gr from gradio.context import Context from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils, paths from backend import memory_management, stream +from backend.args import dynamic_args total_vram = int(memory_management.total_vram) @@ -21,11 +22,14 @@ ui_forge_pin_shared_memory: gr.Radio = None ui_forge_inference_memory: gr.Slider = None forge_unet_storage_dtype_options = { - 'Automatic': None, - 'bnb-nf4': 'nf4', - 'float8-e4m3fn': torch.float8_e4m3fn, - 'bnb-fp4': 'fp4', - 'float8-e5m2': torch.float8_e5m2, + 'Automatic': (None, False), + 'Automatic (fp16 LoRA)': (None, True), + 'bnb-nf4': ('nf4', False), + 'float8-e4m3fn': (torch.float8_e4m3fn, False), + 'float8-e4m3fn (fp16 LoRA)': (torch.float8_e4m3fn, True), + 'bnb-fp4': ('fp4', False), + 'float8-e5m2': (torch.float8_e5m2, False), + 'float8-e5m2 (fp16 LoRA)': (torch.float8_e5m2, True), } module_list = {} @@ -180,10 +184,14 @@ def refresh_model_loading_parameters(): checkpoint_info = select_checkpoint() + unet_storage_dtype, lora_fp16 = forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, (None, False)) + + dynamic_args['online_lora'] = lora_fp16 + model_data.forge_loading_parameters = dict( checkpoint_info=checkpoint_info, additional_modules=shared.opts.forge_additional_modules, - unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None) + unet_storage_dtype=unet_storage_dtype ) print(f'Model selected: {model_data.forge_loading_parameters}') diff --git a/packages_3rdparty/gguf/quants.py b/packages_3rdparty/gguf/quants.py index b43d65db..a5720083 100644 --- a/packages_3rdparty/gguf/quants.py +++ b/packages_3rdparty/gguf/quants.py @@ -158,7 +158,7 @@ class __Quant(ABC): @classmethod 
@abstractmethod def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - raise NotImplementedError + raise NotImplementedError('Low bit LoRA for this data type is not implemented yet. Please select "Automatic (fp16 LoRA)" to use this LoRA.') @classmethod @abstractmethod @@ -370,30 +370,6 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): return (d * qs) + m - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - # WIP - - raise NotImplementedError('Q4_1 Lora is under construction!') - - n_blocks = blocks.shape[0] - - max_vals = blocks.max(dim=-1, keepdim=True).values - min_vals = blocks.min(dim=-1, keepdim=True).values - - d = (max_vals - min_vals) / 15 - id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1 / d) - - qs = torch.trunc((blocks - min_vals) * id + 0.5).to(torch.uint8).clip(0, 15) - - qs = qs.view(n_blocks, 2, block_size // 2) - qs = qs[:, 0, :] | (qs[:, 1, :] << 4) - - d = d.to(torch.float16).view(n_blocks, -1) - m = min_vals.to(torch.float16).view(n_blocks, -1) - - return torch.cat([d, m, qs], dim=-1) - class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0): @classmethod @@ -567,31 +543,6 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1): qs = (ql | (qh << 4)) return (d * qs) + m - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - # WIP - - raise NotImplementedError('Q5_1 Lora is under construction!') - - n_blocks = blocks.shape[0] - - max_val = blocks.max(dim=-1, keepdim=True)[0] - min_val = blocks.min(dim=-1, keepdim=True)[0] - - d = (max_val - min_val) / 31 - id = torch.where(d == 0, torch.zeros_like(d), 1.0 / d) - q = torch.trunc((blocks - min_val) * id + 0.5).clip(0, 31).to(torch.uint8) - - qs = q.view(n_blocks, 2, block_size // 2) - qs = (qs[..., 0, :] & 0x0F) | (qs[..., 1, :] << 4) - - qh = torch.bitwise_right_shift(q.view(n_blocks, 1, 32), torch.arange(4, dtype=torch.uint8, device=blocks.device) * 8).byte() - - d = d.to(torch.float16).view(-1, 2) - min_val = min_val.to(torch.float16).view(-1, 2) - - return torch.cat([d.byte(), min_val.byte(), qh, qs], dim=-1) - class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0): @classmethod @@ -677,10 +628,6 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K): qs = dl * qs - ml return qs.reshape((n_blocks, -1)) - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - raise NotImplementedError('Not Implemented Yet') - class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K): @classmethod @@ -746,10 +693,6 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K): q = (ql.to(torch.int8) - (qh << 2).to(torch.int8)) return (dl * q).reshape((n_blocks, QK_K)) - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - raise NotImplementedError('Not Implemented Yet') - class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): K_SCALE_SIZE = 12 @@ -826,10 +769,6 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) return (d * qs - dm).reshape((n_blocks, QK_K)) - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - raise NotImplementedError('Not Implemented Yet') - class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K): @classmethod @@ -876,10 +815,6 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K): q = (ql | (qh << 4)) return (d * q - dm).reshape((n_blocks, QK_K)) - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, 
type_size) -> torch.Tensor: - raise NotImplementedError('Not Implemented Yet') - class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K): @classmethod @@ -919,10 +854,6 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K): q = q.reshape((n_blocks, QK_K // 16, -1)) return (d * q).reshape((n_blocks, QK_K)) - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - raise NotImplementedError('Not Implemented Yet') - class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS): ksigns: bytes = (