diff --git a/backend/memory_management.py b/backend/memory_management.py
index c6ff6b95..de56e306 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -378,6 +378,7 @@ class LoadedModel:
 
         try:
             self.real_model = self.model.forge_patch_model(patch_model_to)
+            self.model.current_device = self.model.load_device
         except Exception as e:
             self.model.forge_unpatch_model(self.model.offload_device)
             self.model_unload()
diff --git a/backend/operations.py b/backend/operations.py
index 1d79a4c3..da6903cb 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -5,12 +5,53 @@ import torch
 import contextlib
 
 from backend import stream, memory_management, utils
+from backend.patcher.lora import merge_lora_to_weight
 
 
 stash = {}
 
 
-def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
+def get_weight_and_bias(layer, weight_args=None, bias_args=None, weight_fn=None, bias_fn=None):
+    patches = getattr(layer, 'forge_online_loras', None)
+    weight_patches, bias_patches = None, None
+
+    if patches is not None:
+        weight_patches = patches.get('weight', None)
+
+    if patches is not None:
+        bias_patches = patches.get('bias', None)
+
+    weight = None
+    if layer.weight is not None:
+        weight = layer.weight
+        if weight_fn is not None:
+            if weight_args is not None:
+                fn_device = weight_args.get('device', None)
+                if fn_device is not None:
+                    weight = weight.to(device=fn_device)
+            weight = weight_fn(weight)
+        if weight_args is not None:
+            weight = weight.to(**weight_args)
+        if weight_patches is not None:
+            weight = merge_lora_to_weight(patches=weight_patches, weight=weight, key="online weight lora", computation_dtype=weight.dtype)
+
+    bias = None
+    if layer.bias is not None:
+        bias = layer.bias
+        if bias_fn is not None:
+            if bias_args is not None:
+                fn_device = bias_args.get('device', None)
+                if fn_device is not None:
+                    bias = bias.to(device=fn_device)
+            bias = bias_fn(bias)
+        if bias_args is not None:
+            bias = bias.to(**bias_args)
+        if bias_patches is not None:
+            bias = merge_lora_to_weight(patches=bias_patches, weight=bias, key="online bias lora", computation_dtype=bias.dtype)
+    return weight, bias
+
+
+def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False, weight_fn=None, bias_fn=None):
     weight, bias, signal = None, None, None
     non_blocking = True
 
@@ -32,16 +73,10 @@ def weights_manual_cast(layer, x, skip_weight_dtype=False, skip_bias_dtype=False):
 
     if stream.should_use_stream():
         with stream.stream_context()(stream.mover_stream):
-            if layer.weight is not None:
-                weight = layer.weight.to(**weight_args)
-            if layer.bias is not None:
-                bias = layer.bias.to(**bias_args)
+            weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
             signal = stream.mover_stream.record_event()
     else:
-        if layer.weight is not None:
-            weight = layer.weight.to(**weight_args)
-        if layer.bias is not None:
-            bias = layer.bias.to(**bias_args)
+        weight, bias = get_weight_and_bias(layer, weight_args, bias_args, weight_fn=weight_fn, bias_fn=bias_fn)
 
     return weight, bias, signal
 
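Note on the new helper: `get_weight_and_bias` resolves a layer's effective weight in a fixed order — optional move to the cast device, optional `weight_fn` (used for GGUF dequantization below), the `weight_args` dtype/device cast, and finally any online LoRA patches attached to the layer as `forge_online_loras`. A minimal sketch of that order under simplified assumptions (a plain `nn.Linear`, and a stub `toy_merge` standing in for the real `merge_lora_to_weight`):

```python
# Sketch only: `toy_merge` is a stand-in for merge_lora_to_weight, and the patch
# format (plain delta tensors) is simplified for illustration.
import torch

def toy_merge(patches, weight):
    for delta in patches:
        weight = weight + delta.to(weight.dtype)
    return weight

def resolve_weight(layer, weight_args=None, weight_fn=None):
    weight = layer.weight
    if weight_fn is not None:
        weight = weight_fn(weight)            # e.g. GGUF dequantization
    if weight_args is not None:
        weight = weight.to(**weight_args)     # manual cast to the compute dtype/device
    patches = getattr(layer, 'forge_online_loras', {}).get('weight')
    if patches:
        weight = toy_merge(patches, weight)   # online LoRA merged after the cast, per forward pass
    return weight

layer = torch.nn.Linear(4, 4)
layer.forge_online_loras = {'weight': [torch.full((4, 4), 0.01)]}
print(resolve_weight(layer, weight_args=dict(dtype=torch.float16)).dtype)  # torch.float16
```

The real patch entries are the five-element lists built by `LoraLoader.add_patches` further down in this diff, not bare delta tensors.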
@@ -109,7 +144,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.linear(x, weight, bias)
             else:
-                return torch.nn.functional.linear(x, self.weight, self.bias)
+                weight, bias = get_weight_and_bias(self)
+                return torch.nn.functional.linear(x, weight, bias)
 
     class Conv2d(torch.nn.Conv2d):
 
@@ -128,7 +164,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return self._conv_forward(x, weight, bias)
             else:
-                return super().forward(x)
+                weight, bias = get_weight_and_bias(self)
+                return super()._conv_forward(x, weight, bias)
 
     class Conv3d(torch.nn.Conv3d):
 
@@ -147,7 +184,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return self._conv_forward(x, weight, bias)
             else:
-                return super().forward(x)
+                weight, bias = get_weight_and_bias(self)
+                return super()._conv_forward(x, weight, bias)
 
     class Conv1d(torch.nn.Conv1d):
 
@@ -166,7 +204,8 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return self._conv_forward(x, weight, bias)
             else:
-                return super().forward(x)
+                weight, bias = get_weight_and_bias(self)
+                return super()._conv_forward(x, weight, bias)
 
     class ConvTranspose2d(torch.nn.ConvTranspose2d):
 
@@ -188,7 +227,10 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
             else:
-                return super().forward(x, output_size)
+                weight, bias = get_weight_and_bias(self)
+                num_spatial_dims = 2
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+                return torch.nn.functional.conv_transpose2d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
 
     class ConvTranspose1d(torch.nn.ConvTranspose1d):
 
@@ -210,7 +252,10 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
             else:
-                return super().forward(x, output_size)
+                weight, bias = get_weight_and_bias(self)
+                num_spatial_dims = 1
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+                return torch.nn.functional.conv_transpose1d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
 
     class ConvTranspose3d(torch.nn.ConvTranspose3d):
 
@@ -232,7 +277,10 @@ class ForgeOperations:
                 with main_stream_worker(weight, bias, signal):
                     return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
             else:
-                return super().forward(x, output_size)
+                weight, bias = get_weight_and_bias(self)
+                num_spatial_dims = 3
+                output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation)
+                return torch.nn.functional.conv_transpose3d(x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
 
     class GroupNorm(torch.nn.GroupNorm):
 
@@ -328,7 +376,7 @@ except:
     bnb_avaliable = False
 
 
-from backend.operations_gguf import functional_linear_gguf
+from backend.operations_gguf import dequantize_tensor
 
 
 class ForgeOperationsGGUF(ForgeOperations):
@@ -361,12 +409,9 @@ class ForgeOperationsGGUF(ForgeOperations):
             return self
 
         def forward(self, x):
-            if self.parameters_manual_cast:
-                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True)
-                with main_stream_worker(weight, bias, signal):
-                    return functional_linear_gguf(x, weight, bias)
-            else:
-                return functional_linear_gguf(x, self.weight, self.bias)
+            weight, bias, signal = weights_manual_cast(self, x, weight_fn=dequantize_tensor, bias_fn=dequantize_tensor)
+            with main_stream_worker(weight, bias, signal):
+                return torch.nn.functional.linear(x, weight, bias)
 
 
 @contextlib.contextmanager
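Note on the GGUF path: `forward` now always routes through `weights_manual_cast` with `weight_fn=dequantize_tensor`, so the quantized weight is dequantized (and any online LoRA merged) once per forward pass, and the removed `functional_linear_gguf` helper is no longer needed. A rough sketch of that call shape with a stand-in dequantizer (the real one decodes GGUF quantization blocks):

```python
# Sketch only: `fake_dequantize` is a stand-in for backend.operations_gguf.dequantize_tensor.
import torch

def fake_dequantize(tensor):
    if tensor is None:
        return None
    return tensor.float()  # pretend the stored tensor was quantized

class ToyGGUFLinear(torch.nn.Linear):
    def forward(self, x):
        # Mirrors the new ForgeOperationsGGUF path: dequantize while casting, then run a
        # plain float linear; no dedicated quantized-matmul helper is needed anymore.
        weight = fake_dequantize(self.weight).to(dtype=x.dtype, device=x.device)
        bias = fake_dequantize(self.bias)
        if bias is not None:
            bias = bias.to(dtype=x.dtype, device=x.device)
        return torch.nn.functional.linear(x, weight, bias)

layer = ToyGGUFLinear(8, 4)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```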
diff --git a/backend/operations_gguf.py b/backend/operations_gguf.py
index 495e6396..80b77382 100644
--- a/backend/operations_gguf.py
+++ b/backend/operations_gguf.py
@@ -58,13 +58,6 @@ class ParameterGGUF(torch.nn.Parameter):
         return new
 
 
-def functional_linear_gguf(x, weight, bias=None):
-    target_dtype = x.dtype
-    weight = dequantize_tensor(weight).to(target_dtype)
-    bias = dequantize_tensor(bias).to(target_dtype)
-    return torch.nn.functional.linear(x, weight, bias)
-
-
 def dequantize_tensor(tensor):
     if tensor is None:
         return None
diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index f60e60b7..39c28c60 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -4,14 +4,12 @@
 # are from Forge, implemented from scratch (after forge-v1.0.1), and may have
 # certain level of differences.
 
-import time
-import torch
+
 import copy
 import inspect
 
-from tqdm import tqdm
-from backend import memory_management, utils, operations
-from backend.patcher.lora import merge_lora_to_model_weight, LoraLoader
+from backend import memory_management, utils
+from backend.patcher.lora import LoraLoader
 
 
 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
@@ -229,7 +227,6 @@ class ModelPatcher:
 
         if target_device is not None:
             self.model.to(target_device)
-            self.current_device = target_device
 
         return self.model
 
diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index d19556e2..01eb6930 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -5,7 +5,8 @@ import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
 import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
 
 from tqdm import tqdm
-from backend import memory_management, utils, operations
+from backend import memory_management, utils
+from backend.args import dynamic_args
 
 
 class ForgeLoraCollection:
@@ -39,8 +40,10 @@ def model_lora_keys_unet(model, key_map={}):
     return get_function('model_lora_keys_unet')(model, key_map)
 
 
-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
-    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32)
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype):
+    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L33
+
+    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, computation_dtype)
     lora_diff *= alpha
     weight_calc = weight + lora_diff.type(weight.dtype)
     weight_norm = (
@@ -60,7 +63,12 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
     return weight
 
 
-def merge_lora_to_model_weight(patches, weight, key):
+def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
+    # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
+
+    weight_original_dtype = weight.dtype
+    weight = weight.to(dtype=computation_dtype)
+
     for p in patches:
         strength = p[0]
         v = p[1]
@@ -79,7 +87,7 @@
             weight *= strength_model
 
         if isinstance(v, list):
-            v = (merge_lora_to_model_weight(v[1:], v[0].clone(), key),)
+            v = (merge_lora_to_weight(v[1:], v[0].clone(), key),)
 
         patch_type = ''
 
@@ -107,8 +115,8 @@
                 else:
                     weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
         elif patch_type == "lora":
-            mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32)
-            mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32)
+            mat1 = memory_management.cast_to_device(v[0], weight.device, computation_dtype)
+            mat2 = memory_management.cast_to_device(v[1], weight.device, computation_dtype)
             dora_scale = v[4]
             if v[2] is not None:
                 alpha = v[2] / mat2.shape[0]
@@ -116,13 +124,13 @@
                 alpha = 1.0
 
             if v[3] is not None:
-                mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32)
+                mat3 = memory_management.cast_to_device(v[3], weight.device, computation_dtype)
                 final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                 mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
             try:
                 lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -141,23 +149,23 @@
             if w1 is None:
                 dim = w1_b.shape[0]
-                w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32),
-                              memory_management.cast_to_device(w1_b, weight.device, torch.float32))
+                w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, computation_dtype),
+                              memory_management.cast_to_device(w1_b, weight.device, computation_dtype))
             else:
-                w1 = memory_management.cast_to_device(w1, weight.device, torch.float32)
+                w1 = memory_management.cast_to_device(w1, weight.device, computation_dtype)
 
             if w2 is None:
                 dim = w2_b.shape[0]
                 if t2 is None:
-                    w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32),
-                                  memory_management.cast_to_device(w2_b, weight.device, torch.float32))
+                    w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, computation_dtype),
+                                  memory_management.cast_to_device(w2_b, weight.device, computation_dtype))
                 else:
                     w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      memory_management.cast_to_device(t2, weight.device, torch.float32),
-                                      memory_management.cast_to_device(w2_b, weight.device, torch.float32),
-                                      memory_management.cast_to_device(w2_a, weight.device, torch.float32))
+                                      memory_management.cast_to_device(t2, weight.device, computation_dtype),
+                                      memory_management.cast_to_device(w2_b, weight.device, computation_dtype),
+                                      memory_management.cast_to_device(w2_a, weight.device, computation_dtype))
             else:
-                w2 = memory_management.cast_to_device(w2, weight.device, torch.float32)
+                w2 = memory_management.cast_to_device(w2, weight.device, computation_dtype)
 
             if len(w2.shape) == 4:
                 w1 = w1.unsqueeze(2).unsqueeze(2)
@@ -169,7 +177,7 @@
             try:
                 lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
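Note on `computation_dtype`: every branch above now computes in a caller-chosen dtype instead of hard-coded `torch.float32`, which is what lets the online path merge directly in the weight's own dtype (fp16). The core of the plain `lora` branch is the usual low-rank update `W += strength * (alpha / rank) * (up @ down)`; a self-contained sketch of that update plus the dtype round-trip, without the conv/DoRA handling:

```python
# Sketch only: a minimal stand-alone version of the "lora" branch; `up`/`down` play the
# roles of mat1/mat2 (v[0]/v[1]) and alpha follows the usual alpha/rank convention.
import torch

def merge_plain_lora(weight, up, down, alpha, strength, computation_dtype=torch.float32):
    original_dtype = weight.dtype
    w = weight.to(dtype=computation_dtype)
    rank = down.shape[0]
    lora_diff = torch.mm(up.to(computation_dtype).flatten(start_dim=1),
                         down.to(computation_dtype).flatten(start_dim=1)).reshape(w.shape)
    w = w + (strength * alpha / rank) * lora_diff
    return w.to(dtype=original_dtype)  # hand back the storage dtype, as merge_lora_to_weight does

weight = torch.randn(16, 32, dtype=torch.float16)
up, down = torch.randn(16, 4), torch.randn(4, 32)
merged = merge_plain_lora(weight, up, down, alpha=4.0, strength=0.8,
                          computation_dtype=torch.float16)  # fp16 math, as the online path uses
print(merged.dtype, merged.shape)
```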
@@ -190,24 +198,24 @@
                 t1 = v[5]
                 t2 = v[6]
                 m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                  memory_management.cast_to_device(t1, weight.device, torch.float32),
-                                  memory_management.cast_to_device(w1b, weight.device, torch.float32),
-                                  memory_management.cast_to_device(w1a, weight.device, torch.float32))
+                                  memory_management.cast_to_device(t1, weight.device, computation_dtype),
+                                  memory_management.cast_to_device(w1b, weight.device, computation_dtype),
+                                  memory_management.cast_to_device(w1a, weight.device, computation_dtype))
 
                 m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                  memory_management.cast_to_device(t2, weight.device, torch.float32),
-                                  memory_management.cast_to_device(w2b, weight.device, torch.float32),
-                                  memory_management.cast_to_device(w2a, weight.device, torch.float32))
+                                  memory_management.cast_to_device(t2, weight.device, computation_dtype),
+                                  memory_management.cast_to_device(w2b, weight.device, computation_dtype),
+                                  memory_management.cast_to_device(w2a, weight.device, computation_dtype))
             else:
-                m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32),
-                              memory_management.cast_to_device(w1b, weight.device, torch.float32))
-                m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32),
-                              memory_management.cast_to_device(w2b, weight.device, torch.float32))
+                m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, computation_dtype),
+                              memory_management.cast_to_device(w1b, weight.device, computation_dtype))
+                m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, computation_dtype),
+                              memory_management.cast_to_device(w2b, weight.device, computation_dtype))
 
             try:
                 lora_diff = (m1 * m2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -221,15 +229,15 @@
 
             dora_scale = v[5]
 
-            a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
-            a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
-            b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
-            b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
+            a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, computation_dtype)
+            a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, computation_dtype)
+            b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, computation_dtype)
+            b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, computation_dtype)
 
             try:
                 lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation_dtype))
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -243,14 +251,19 @@
         if old_weight is not None:
             weight = old_weight
 
+    weight = weight.to(dtype=weight_original_dtype)
     return weight
 
 
+from backend import operations
+
+
 class LoraLoader:
     def __init__(self, model):
         self.model = model
         self.patches = {}
         self.backup = {}
+        self.online_backup = []
         self.dirty = False
 
     def clear_patches(self):
@@ -277,7 +290,7 @@
             if key in model_sd:
                 p.add(k)
                 current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
+                current_patches.append([strength_patch, patches[k], strength_model, offset, function])
                 self.patches[key] = current_patches
 
         self.dirty = True
@@ -293,10 +306,12 @@
 
         # Restore
 
-        for k, w in self.backup.items():
-            if target_device is not None:
-                w = w.to(device=target_device)
+        for m in set(self.online_backup):
+            del m.forge_online_loras
 
+        self.online_backup = []
+
+        for k, w in self.backup.items():
             if not isinstance(w, torch.nn.Parameter):
                 # In very few cases
                 w = torch.nn.Parameter(w, requires_grad=False)
@@ -305,11 +320,13 @@
 
         self.backup = {}
 
+        online_mode = dynamic_args.get('online_lora', False)
+
         # Patch
 
         for key, current_patches in (tqdm(self.patches.items(), desc=f'Patching LoRAs for {type(self.model).__name__}') if len(self.patches) > 0 else self.patches):
             try:
-                parent_layer, weight = utils.get_attr_with_parent(self.model, key)
+                parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
                 assert isinstance(weight, torch.nn.Parameter)
             except:
                 raise ValueError(f"Wrong LoRA Key: {key}")
@@ -317,6 +334,14 @@
             if key not in self.backup:
                 self.backup[key] = weight.to(device=offload_device)
 
+            if online_mode:
+                if not hasattr(parent_layer, 'forge_online_loras'):
+                    parent_layer.forge_online_loras = {}
+
+                parent_layer.forge_online_loras[child_key] = current_patches
+                self.online_backup.append(parent_layer)
+                continue
+
             bnb_layer = None
 
             if operations.bnb_avaliable:
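Note on the two refresh paths: in online mode the weight itself is never touched — the patch list is parked on the parent module as `forge_online_loras[child_key]`, the module is remembered in `online_backup` so the next refresh can detach it, and the merge is skipped. A condensed sketch of both paths under simplified assumptions (plain delta patches, no bnb/GGUF handling):

```python
# Sketch only: `merge` is a stand-in for merge_lora_to_weight and `layers` for the real
# key -> (parent, child, parameter) lookup done via utils.get_attr_with_parent.
import torch

def merge(patch_list, weight):
    for delta in patch_list:                       # simplified patch format: plain deltas
        weight = weight + delta.to(weight.dtype)
    return weight

def refresh(layers, patches, online_mode, backup, online_backup):
    for parent in set(online_backup):              # detach patches attached by a previous refresh
        if hasattr(parent, 'forge_online_loras'):
            del parent.forge_online_loras
    online_backup.clear()

    for key, patch_list in patches.items():
        parent, child, weight = layers[key]
        backup.setdefault(key, weight.detach().clone())
        if online_mode:
            # Defer merging to forward time; get_weight_and_bias reads this attribute.
            parent.forge_online_loras = {**getattr(parent, 'forge_online_loras', {}), child: patch_list}
            online_backup.append(parent)
            continue
        weight.data.copy_(merge(patch_list, weight.data.float()).to(weight.dtype))  # bake once
```

Because `add_patches` now stores lists rather than tuples, the same patch objects can later be moved to the GPU in place by `nested_move_to_device` while the layers keep their references.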
@@ -357,17 +382,13 @@
                 gguf_real_shape = weight.gguf_real_shape
                 weight = dequantize_tensor(weight)
 
-            weight_original_dtype = weight.dtype
-
             try:
-                weight = weight.to(dtype=torch.float32)
-                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
             except:
                 print('Patching LoRA weights failed. Retrying by offloading models.')
                 self.model.to(device=offload_device)
                 memory_management.soft_empty_cache()
-                weight = weight.to(dtype=torch.float32)
-                weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+                weight = merge_lora_to_weight(current_patches, weight, key, computation_dtype=torch.float32)
 
             if bnb_layer is not None:
                 bnb_layer.reload_weight(weight)
diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py
index c51ee742..a9fbec0c 100644
--- a/backend/sampling/sampling_function.py
+++ b/backend/sampling/sampling_function.py
@@ -10,6 +10,8 @@ import collections
 from backend import memory_management
 from backend.sampling.condition import Condition, compile_conditions, compile_weighted_conditions
 from backend.operations import cleanup_cache
+from backend.args import dynamic_args
+from backend import utils
 
 
 def get_area_and_mult(conds, x_in, timestep_in):
@@ -353,10 +355,19 @@
         additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
         additional_model_patchers += unet.controlnet_linked_list.get_models()
 
+    if dynamic_args.get('online_lora', False):
+        lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
+        additional_inference_memory += lora_memory
+
     memory_management.load_models_gpu(
         models=[unet] + additional_model_patchers,
         memory_required=unet_inference_memory + additional_inference_memory)
 
+    if dynamic_args.get('online_lora', False):
+        utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
+
+    unet.lora_loader.patches = {}
+
     real_model = unet.model
 
     percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)
diff --git a/backend/utils.py b/backend/utils.py
index 20f53ff7..2d15a76a 100644
--- a/backend/utils.py
+++ b/backend/utils.py
@@ -79,10 +79,11 @@ def get_attr(obj, attr):
 def get_attr_with_parent(obj, attr):
     attrs = attr.split(".")
     parent = obj
+    name = None
     for name in attrs:
         parent = obj
         obj = getattr(obj, name)
-    return parent, obj
+    return parent, name, obj
 
 
 def calculate_parameters(sd, prefix=""):
@@ -108,3 +109,32 @@ def fp16_fix(x):
     if x.dtype in [torch.float16]:
         return x.clip(-32768.0, 32768.0)
     return x
+
+
+def nested_compute_size(obj):
+    module_mem = 0
+
+    if isinstance(obj, dict):
+        for key in obj:
+            module_mem += nested_compute_size(obj[key])
+    elif isinstance(obj, list) or isinstance(obj, tuple):
+        for i in range(len(obj)):
+            module_mem += nested_compute_size(obj[i])
+    elif isinstance(obj, torch.Tensor):
+        module_mem += obj.nelement() * obj.element_size()
+
+    return module_mem
+
+
+def nested_move_to_device(obj, device):
+    if isinstance(obj, dict):
+        for key in obj:
+            obj[key] = nested_move_to_device(obj[key], device)
+    elif isinstance(obj, list):
+        for i in range(len(obj)):
+            obj[i] = nested_move_to_device(obj[i], device)
+    elif isinstance(obj, tuple):
+        obj = tuple(nested_move_to_device(i, device) for i in obj)
+    elif isinstance(obj, torch.Tensor):
+        return obj.to(device)
+    return obj
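Note on the two helpers: `sampling_prepare` uses `nested_compute_size` to budget extra VRAM for the un-merged patches and `nested_move_to_device` to push them to the UNet's device before sampling. A usage sketch on a patches-shaped structure, assuming it runs inside the repo so `backend.utils` is importable (the key name and tensor shapes are made up for illustration):

```python
# Sketch only: the dict layout mimics LoraLoader.patches (key -> list of patch entries,
# each entry being [strength_patch, patch_value, strength_model, offset, function]).
import torch
from backend import utils  # nested_compute_size / nested_move_to_device from this diff

patches = {
    'diffusion_model.some_block.proj.weight': [
        [0.8, (torch.randn(16, 4), torch.randn(4, 320)), 1.0, None, None],
    ],
}

bytes_needed = utils.nested_compute_size(patches)             # walks dicts/lists/tuples, sums tensor bytes
patches = utils.nested_move_to_device(patches, device='cpu')  # tensors inside tuples are rebuilt on the target device
print(f'{bytes_needed} bytes of LoRA patches')
```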
diff --git a/launch.py b/launch.py
index 10aa5463..c0568c7b 100644
--- a/launch.py
+++ b/launch.py
@@ -1,3 +1,6 @@
+# import faulthandler
+# faulthandler.enable()
+
 from modules import launch_utils
 
 args = launch_utils.args
diff --git a/modules/processing.py b/modules/processing.py
index bcb41827..f1de9b6d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -799,6 +799,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             memory_management.unload_all_models()
 
         if need_global_unload:
+            p.sd_model.forge_objects.unet.lora_loader.dirty = True
            p.clear_prompt_cache()
 
         need_global_unload = False
diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py
index 04264e1f..2f52d287 100644
--- a/modules_forge/main_entry.py
+++ b/modules_forge/main_entry.py
@@ -5,6 +5,7 @@ import gradio as gr
 from gradio.context import Context
 from modules import shared_items, shared, ui_common, sd_models, processing, infotext_utils, paths
 from backend import memory_management, stream
+from backend.args import dynamic_args
 
 
 total_vram = int(memory_management.total_vram)
@@ -21,11 +22,14 @@ ui_forge_pin_shared_memory: gr.Radio = None
 ui_forge_inference_memory: gr.Slider = None
 
 forge_unet_storage_dtype_options = {
-    'Automatic': None,
-    'bnb-nf4': 'nf4',
-    'float8-e4m3fn': torch.float8_e4m3fn,
-    'bnb-fp4': 'fp4',
-    'float8-e5m2': torch.float8_e5m2,
+    'Automatic': (None, False),
+    'Automatic (fp16 LoRA)': (None, True),
+    'bnb-nf4': ('nf4', False),
+    'float8-e4m3fn': (torch.float8_e4m3fn, False),
+    'float8-e4m3fn (fp16 LoRA)': (torch.float8_e4m3fn, True),
+    'bnb-fp4': ('fp4', False),
+    'float8-e5m2': (torch.float8_e5m2, False),
+    'float8-e5m2 (fp16 LoRA)': (torch.float8_e5m2, True),
 }
 
 module_list = {}
@@ -180,10 +184,14 @@ def refresh_model_loading_parameters():
 
     checkpoint_info = select_checkpoint()
 
+    unet_storage_dtype, lora_fp16 = forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, (None, False))
+
+    dynamic_args['online_lora'] = lora_fp16
+
     model_data.forge_loading_parameters = dict(
         checkpoint_info=checkpoint_info,
         additional_modules=shared.opts.forge_additional_modules,
-        unet_storage_dtype=forge_unet_storage_dtype_options.get(shared.opts.forge_unet_storage_dtype, None)
+        unet_storage_dtype=unet_storage_dtype
     )
 
     print(f'Model selected: {model_data.forge_loading_parameters}')
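Note on the option table: each dropdown entry now maps to a `(unet_storage_dtype, online_lora)` pair, and `refresh_model_loading_parameters` copies the second element into `dynamic_args['online_lora']`, the flag checked by `LoraLoader.refresh` and `sampling_prepare` above. A tiny sketch of that lookup with a stand-in `dynamic_args` dict:

```python
# Sketch only: a standalone copy of a few options to show the tuple convention;
# `dynamic_args` here is a plain dict standing in for backend.args.dynamic_args.
import torch

forge_unet_storage_dtype_options = {
    'Automatic': (None, False),
    'Automatic (fp16 LoRA)': (None, True),
    'float8-e4m3fn': (torch.float8_e4m3fn, False),
    'float8-e4m3fn (fp16 LoRA)': (torch.float8_e4m3fn, True),
}
dynamic_args = {}

selected = 'float8-e4m3fn (fp16 LoRA)'
unet_storage_dtype, lora_fp16 = forge_unet_storage_dtype_options.get(selected, (None, False))
dynamic_args['online_lora'] = lora_fp16  # True -> patches stay separate and merge at forward time

print(unet_storage_dtype, dynamic_args)  # torch.float8_e4m3fn {'online_lora': True}
```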
Please select "Automatic (fp16 LoRA)" to use this LoRA.') @classmethod @abstractmethod @@ -370,30 +370,6 @@ class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): return (d * qs) + m - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - # WIP - - raise NotImplementedError('Q4_1 Lora is under construction!') - - n_blocks = blocks.shape[0] - - max_vals = blocks.max(dim=-1, keepdim=True).values - min_vals = blocks.min(dim=-1, keepdim=True).values - - d = (max_vals - min_vals) / 15 - id = torch.where(d == 0, torch.tensor(0.0, device=d.device), 1 / d) - - qs = torch.trunc((blocks - min_vals) * id + 0.5).to(torch.uint8).clip(0, 15) - - qs = qs.view(n_blocks, 2, block_size // 2) - qs = qs[:, 0, :] | (qs[:, 1, :] << 4) - - d = d.to(torch.float16).view(n_blocks, -1) - m = min_vals.to(torch.float16).view(n_blocks, -1) - - return torch.cat([d, m, qs], dim=-1) - class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0): @classmethod @@ -567,31 +543,6 @@ class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1): qs = (ql | (qh << 4)) return (d * qs) + m - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - # WIP - - raise NotImplementedError('Q5_1 Lora is under construction!') - - n_blocks = blocks.shape[0] - - max_val = blocks.max(dim=-1, keepdim=True)[0] - min_val = blocks.min(dim=-1, keepdim=True)[0] - - d = (max_val - min_val) / 31 - id = torch.where(d == 0, torch.zeros_like(d), 1.0 / d) - q = torch.trunc((blocks - min_val) * id + 0.5).clip(0, 31).to(torch.uint8) - - qs = q.view(n_blocks, 2, block_size // 2) - qs = (qs[..., 0, :] & 0x0F) | (qs[..., 1, :] << 4) - - qh = torch.bitwise_right_shift(q.view(n_blocks, 1, 32), torch.arange(4, dtype=torch.uint8, device=blocks.device) * 8).byte() - - d = d.to(torch.float16).view(-1, 2) - min_val = min_val.to(torch.float16).view(-1, 2) - - return torch.cat([d.byte(), min_val.byte(), qh, qs], dim=-1) - class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0): @classmethod @@ -677,10 +628,6 @@ class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K): qs = dl * qs - ml return qs.reshape((n_blocks, -1)) - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - raise NotImplementedError('Not Implemented Yet') - class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K): @classmethod @@ -746,10 +693,6 @@ class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K): q = (ql.to(torch.int8) - (qh << 2).to(torch.int8)) return (dl * q).reshape((n_blocks, QK_K)) - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - raise NotImplementedError('Not Implemented Yet') - class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): K_SCALE_SIZE = 12 @@ -826,10 +769,6 @@ class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K): qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) return (d * qs - dm).reshape((n_blocks, QK_K)) - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - raise NotImplementedError('Not Implemented Yet') - class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K): @classmethod @@ -876,10 +815,6 @@ class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K): q = (ql | (qh << 4)) return (d * q - dm).reshape((n_blocks, QK_K)) - @classmethod - def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor: - raise NotImplementedError('Not Implemented Yet') - class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K): @classmethod @@ -919,10 +854,6 @@ class Q6_K(__Quant, 
@@ -919,10 +854,6 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K):
         q = q.reshape((n_blocks, QK_K // 16, -1))
         return (d * q).reshape((n_blocks, QK_K))
 
-    @classmethod
-    def quantize_blocks_pytorch(cls, blocks, block_size, type_size) -> torch.Tensor:
-        raise NotImplementedError('Not Implemented Yet')
-
 
 class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
     ksigns: bytes = (