From d1d0ec46aa20c3167d14f6f6caf371f94e93ac88 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Fri, 30 Aug 2024 15:14:32 -0700
Subject: [PATCH] Maintain patching-related code

1. Fix several problems related to layerdiffuse not being unloaded
2. Fix several problems related to Fooocus inpaint
3. Slightly speed up on-the-fly LoRAs by precomputing them to the
   computation dtype
---
 backend/memory_management.py                 |  2 +-
 backend/operations.py                        |  4 +-
 backend/patcher/base.py                      | 40 +++++++++++
 backend/patcher/clip.py                      |  3 +
 backend/patcher/lora.py                      | 67 +++++--------------
 backend/patcher/unet.py                      |  5 +-
 backend/sampling/sampling_function.py        | 12 ++--
 backend/utils.py                             | 25 ++++---
 .../scripts/forge_fooocus_inpaint.py         |  8 +--
 extensions-builtin/sd_forge_lora/networks.py | 52 +++++++-------
 modules/processing.py                        |  2 -
 11 files changed, 119 insertions(+), 101 deletions(-)

diff --git a/backend/memory_management.py b/backend/memory_management.py
index 66143633..cdcbbed6 100644
--- a/backend/memory_management.py
+++ b/backend/memory_management.py
@@ -513,7 +513,7 @@ class LoadedModel:
 
         bake_gguf_model(self.real_model)
 
-        self.model.lora_loader.refresh(offload_device=self.model.offload_device)
+        self.model.refresh_loras()
 
         if is_intel_xpu() and not args.disable_ipex_hijack:
             self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
diff --git a/backend/operations.py b/backend/operations.py
index d843b678..5200093b 100644
--- a/backend/operations.py
+++ b/backend/operations.py
@@ -121,8 +121,10 @@ current_bnb_dtype = None
 
 class ForgeOperations:
     class Linear(torch.nn.Module):
-        def __init__(self, *args, **kwargs):
+        def __init__(self, in_features, out_features, *args, **kwargs):
             super().__init__()
+            self.in_features = in_features
+            self.out_features = out_features
             self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
             self.weight = None
             self.bias = None
diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index 1592e3d0..bc27ee24 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -52,6 +52,7 @@ class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
         self.size = size
         self.model = model
+        self.lora_patches = {}
         self.object_patches = {}
         self.object_patches_backup = {}
         self.model_options = {"transformer_options": {}}
@@ -77,6 +78,7 @@ class ModelPatcher:
 
     def clone(self):
         n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
+        n.lora_patches = self.lora_patches.copy()
         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         return n
@@ -86,6 +88,44 @@ class ModelPatcher:
             return True
         return False
 
+    def add_patches(self, *, filename, patches, strength_patch=1.0, strength_model=1.0, online_mode=False):
+        lora_identifier = (filename, strength_patch, strength_model, online_mode)
+        this_patches = {}
+
+        p = set()
+        model_keys = set(k for k, _ in self.model.named_parameters())
+
+        for k in patches:
+            offset = None
+            function = None
+
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+                if len(k) > 2:
+                    function = k[2]
+
+            if key in model_keys:
+                p.add(k)
+                current_patches = this_patches.get(key, [])
+                current_patches.append([strength_patch, patches[k], strength_model, offset, function])
+                this_patches[key] = current_patches
+
+        self.lora_patches[lora_identifier] = this_patches
+        return p
+
+    def has_online_lora(self):
+        for (filename, strength_patch, strength_model, online_mode), this_patches in self.lora_patches.items():
+            if online_mode:
+                return True
+        return False
+
+    def refresh_loras(self):
+        self.lora_loader.refresh(lora_patches=self.lora_patches, offload_device=self.offload_device)
+        return
+
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)
 
diff --git a/backend/patcher/clip.py b/backend/patcher/clip.py
index bb74ba13..2337397b 100644
--- a/backend/patcher/clip.py
+++ b/backend/patcher/clip.py
@@ -25,3 +25,6 @@ class CLIP:
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         return n
+
+    def add_patches(self, *args, **kwargs):
+        return self.patcher.add_patches(*args, **kwargs)
diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index a9c7e84d..f7a8819e 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -286,55 +286,24 @@ from backend import operations
 class LoraLoader:
     def __init__(self, model):
         self.model = model
-        self.patches = {}
         self.backup = {}
         self.online_backup = []
-        self.dirty = False
-        self.online_mode = False
-
-    def clear_patches(self):
-        self.patches.clear()
-        self.dirty = True
-        return
-
-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
-        model_sd = self.model.state_dict()
-
-        for k in patches:
-            offset = None
-            function = None
-
-            if isinstance(k, str):
-                key = k
-            else:
-                offset = k[1]
-                key = k[0]
-                if len(k) > 2:
-                    function = k[2]
-
-            if key in model_sd:
-                p.add(k)
-                current_patches = self.patches.get(key, [])
-                current_patches.append([strength_patch, patches[k], strength_model, offset, function])
-                self.patches[key] = current_patches
-
-        self.dirty = True
-
-        self.online_mode = dynamic_args.get('online_lora', False)
-
-        if hasattr(self.model, 'storage_dtype'):
-            if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
-                self.online_mode = False
-
-        return list(p)
+        self.loaded_hash = str([])
 
     @torch.inference_mode()
-    def refresh(self, offload_device=torch.device('cpu')):
-        if not self.dirty:
+    def refresh(self, lora_patches, offload_device=torch.device('cpu')):
+        hashes = str(list(lora_patches.keys()))
+
+        if hashes == self.loaded_hash:
             return
 
-        self.dirty = False
+        # Merge Patches
+
+        all_patches = {}
+
+        for (_, _, _, online_mode), patches in lora_patches.items():
+            for key, current_patches in patches.items():
+                all_patches[(key, online_mode)] = all_patches.get((key, online_mode), []) + current_patches
 
         # Initialize
 
@@ -362,14 +331,14 @@ class LoraLoader:
 
         # Patch
 
-        for key, current_patches in self.patches.items():
+        for (key, online_mode), current_patches in all_patches.items():
             try:
                 parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
                 assert isinstance(weight, torch.nn.Parameter)
             except:
                 raise ValueError(f"Wrong LoRA Key: {key}")
 
-            if self.online_mode:
+            if online_mode:
                 if not hasattr(parent_layer, 'forge_online_loras'):
                     parent_layer.forge_online_loras = {}
 
@@ -418,11 +387,5 @@ class LoraLoader:
         # End
 
         set_parameter_devices(self.model, parameter_devices=parameter_devices)
-
-        if len(self.patches) > 0:
-            if self.online_mode:
-                print(f'Patched LoRAs on-the-fly; ', end='')
-            else:
-                print(f'Patched LoRAs by precomputing model weights; ', end='')
-
+        self.loaded_hash = hashes
         return
diff --git a/backend/patcher/unet.py b/backend/patcher/unet.py
index e390a4b4..31d3051e 100644
--- a/backend/patcher/unet.py
+++ b/backend/patcher/unet.py
@@ -176,7 +176,7 @@ class UnetPatcher(ModelPatcher):
         self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
         return
 
-    def load_frozen_patcher(self, state_dict, strength):
+    def load_frozen_patcher(self, filename, state_dict, strength):
         patch_dict = {}
         for k, w in state_dict.items():
             model_key, patch_type, weight_index = k.split('::')
@@ -191,6 +191,5 @@ class UnetPatcher(ModelPatcher):
         for patch_type, weight_list in v.items():
             patch_flat[model_key] = (patch_type, weight_list)
 
-        self.lora_loader.clear_patches()
-        self.lora_loader.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
+        self.add_patches(filename=filename, patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
         return
diff --git a/backend/sampling/sampling_function.py b/backend/sampling/sampling_function.py
index dd8b3088..27cb1308 100644
--- a/backend/sampling/sampling_function.py
+++ b/backend/sampling/sampling_function.py
@@ -376,16 +376,16 @@ def sampling_prepare(unet, x):
         additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
         additional_model_patchers += unet.controlnet_linked_list.get_models()
 
-    if unet.lora_loader.online_mode:
-        lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
+    if unet.has_online_lora():
+        lora_memory = utils.nested_compute_size(unet.lora_patches, element_size=utils.dtype_to_element_size(unet.model.computation_dtype))
         additional_inference_memory += lora_memory
 
     memory_management.load_models_gpu(
         models=[unet] + additional_model_patchers,
         memory_required=unet_inference_memory + additional_inference_memory)
 
-    if unet.lora_loader.online_mode:
-        utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
+    if unet.has_online_lora():
+        utils.nested_move_to_device(unet.lora_patches, device=unet.current_device, dtype=unet.model.computation_dtype)
 
     real_model = unet.model
 
@@ -398,8 +398,8 @@ def sampling_prepare(unet, x):
 
 
 def sampling_cleanup(unet):
-    if unet.lora_loader.online_mode:
-        utils.nested_move_to_device(unet.lora_loader.patches, device=unet.offload_device)
+    if unet.has_online_lora():
+        utils.nested_move_to_device(unet.lora_patches, device=unet.offload_device)
     for cnet in unet.list_controlnets():
         cnet.cleanup()
     cleanup_cache()
diff --git a/backend/utils.py b/backend/utils.py
index 7c7741db..c88fceae 100644
--- a/backend/utils.py
+++ b/backend/utils.py
@@ -111,32 +111,39 @@ def fp16_fix(x):
     return x
 
 
-def nested_compute_size(obj):
+def dtype_to_element_size(dtype):
+    if isinstance(dtype, torch.dtype):
+        return torch.tensor([], dtype=dtype).element_size()
+    else:
+        raise ValueError(f"Invalid dtype: {dtype}")
+
+
+def nested_compute_size(obj, element_size):
     module_mem = 0
 
     if isinstance(obj, dict):
         for key in obj:
-            module_mem += nested_compute_size(obj[key])
+            module_mem += nested_compute_size(obj[key], element_size)
     elif isinstance(obj, list) or isinstance(obj, tuple):
         for i in range(len(obj)):
-            module_mem += nested_compute_size(obj[i])
+            module_mem += nested_compute_size(obj[i], element_size)
     elif isinstance(obj, torch.Tensor):
-        module_mem += obj.nelement() * obj.element_size()
+        module_mem += obj.nelement() * element_size
 
     return module_mem
 
 
-def nested_move_to_device(obj, device):
+def nested_move_to_device(obj, **kwargs):
     if isinstance(obj, dict):
         for key in obj:
-            obj[key] = nested_move_to_device(obj[key], device)
+            obj[key] = nested_move_to_device(obj[key], **kwargs)
     elif isinstance(obj, list):
         for i in range(len(obj)):
-            obj[i] = nested_move_to_device(obj[i], device)
+            obj[i] = nested_move_to_device(obj[i], **kwargs)
     elif isinstance(obj, tuple):
-        obj = tuple(nested_move_to_device(i, device) for i in obj)
+        obj = tuple(nested_move_to_device(i, **kwargs) for i in obj)
     elif isinstance(obj, torch.Tensor):
-        return obj.to(device)
+        return obj.to(**kwargs)
 
     return obj
diff --git a/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py b/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py
index 7355e722..5ae4ca36 100644
--- a/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py
+++ b/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py
@@ -55,13 +55,14 @@ class FooocusInpaintPatcher(ControlModelPatcher):
     def try_build_from_state_dict(state_dict, ckpt_path):
         if 'diffusion_model.time_embed.0.weight' in state_dict:
             if len(state_dict['diffusion_model.time_embed.0.weight']) == 3:
-                return FooocusInpaintPatcher(state_dict)
+                return FooocusInpaintPatcher(state_dict, ckpt_path)
         return None
 
-    def __init__(self, state_dict):
+    def __init__(self, state_dict, filename):
         super().__init__()
         self.state_dict = state_dict
+        self.filename = filename
         self.inpaint_head = InpaintHead().to(device=torch.device('cpu'), dtype=torch.float32)
         self.inpaint_head.load_state_dict(load_torch_file(os.path.join(os.path.dirname(__file__), 'fooocus_inpaint_head')))
 
@@ -95,8 +96,7 @@ class FooocusInpaintPatcher(ControlModelPatcher):
         lora_keys.update({x: x for x in unet.model.state_dict().keys()})
         loaded_lora = load_fooocus_patch(self.state_dict, lora_keys)
 
-        unet.lora_loader.clear_patches()  # TODO
-        patched = unet.lora_loader.add_patches(loaded_lora, 1.0)
+        patched = unet.add_patches(filename=self.filename, patches=loaded_lora)
 
         not_patched_count = sum(1 for x in loaded_lora if x not in patched)
 
diff --git a/extensions-builtin/sd_forge_lora/networks.py b/extensions-builtin/sd_forge_lora/networks.py
index 55096fd8..b287c1e0 100644
--- a/extensions-builtin/sd_forge_lora/networks.py
+++ b/extensions-builtin/sd_forge_lora/networks.py
@@ -1,21 +1,18 @@
 from __future__ import annotations
-import gradio as gr
-import logging
+
 import os
 import re
-
-import functools
-import network
-
 import torch
-from typing import Union
+import network
+import functools
+from backend.args import dynamic_args
 from modules import shared, sd_models, errors, scripts
 from backend.utils import load_torch_file
 from backend.patcher.lora import model_lora_keys_clip, model_lora_keys_unet, load_lora
 
 
-def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
+def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default', online_mode=False):
     model_flag = type(model.model).__name__ if model is not None else 'default'
 
     unet_keys = model_lora_keys_unet(model.model) if model is not None else {}
 
@@ -32,23 +29,28 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
     if len(lora_unmatch) > 0:
         print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
 
-    if model is not None and len(lora_unet) > 0:
-        loaded_keys = model.lora_loader.add_patches(lora_unet, strength_model)
+    new_model = model.clone() if model is not None else None
+    new_clip = clip.clone() if clip is not None else None
+
+    if new_model is not None and len(lora_unet) > 0:
+        loaded_keys = new_model.add_patches(filename=filename, patches=lora_unet, strength_patch=strength_model, online_mode=online_mode)
         skipped_keys = [item for item in lora_unet if item not in loaded_keys]
         if len(skipped_keys) > 12:
             print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
         else:
-            print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
+            print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
+        model = new_model
 
-    if clip is not None and len(lora_clip) > 0:
-        loaded_keys = clip.patcher.lora_loader.add_patches(lora_clip, strength_clip)
+    if new_clip is not None and len(lora_clip) > 0:
+        loaded_keys = new_clip.add_patches(filename=filename, patches=lora_clip, strength_patch=strength_clip, online_mode=online_mode)
         skipped_keys = [item for item in lora_clip if item not in loaded_keys]
         if len(skipped_keys) > 12:
             print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
         else:
-            print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
+            print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
+        clip = new_clip
 
-    return
+    return model, clip
 
 
 @functools.lru_cache(maxsize=5)
@@ -97,9 +99,14 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
             network_on_disk.read_hash()
         loaded_networks.append(net)
 
+    online_mode = dynamic_args.get('online_lora', False)
+
+    if current_sd.forge_objects.unet.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+        online_mode = False
+
     compiled_lora_targets = []
     for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
-        compiled_lora_targets.append([a.filename, b, c])
+        compiled_lora_targets.append([a.filename, b, c, online_mode])
 
     compiled_lora_targets_hash = str(compiled_lora_targets)
 
@@ -107,15 +114,14 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
         return
 
     current_sd.current_lora_hash = compiled_lora_targets_hash
-    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet.clone()
-    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip.clone()
+    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
+    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
 
-    current_sd.forge_objects.unet.lora_loader.clear_patches()
-    current_sd.forge_objects.clip.patcher.lora_loader.clear_patches()
-
-    for filename, strength_model, strength_clip in compiled_lora_targets:
+    for filename, strength_model, strength_clip, online_mode in compiled_lora_targets:
        lora_sd = load_lora_state_dict(filename)
-        load_lora_for_models(current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, filename=filename)
+        current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
+            current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
+            filename=filename, online_mode=online_mode)
 
     current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
     return
diff --git a/modules/processing.py b/modules/processing.py
index 8b9b46b3..e6ebb6c7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -802,8 +802,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
         memory_management.unload_all_models()
 
     if need_global_unload:
-        p.sd_model.current_lora_hash = str([])
-        p.sd_model.forge_objects.unet.lora_loader.dirty = True
         p.clear_prompt_cache()
 
     need_global_unload = False
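
A minimal sketch (toy classes, not the actual Forge code) of why moving LoRA
bookkeeping onto the patcher fixes the "not unloaded" problems in items 1 and
2: clone() copies lora_patches, so patches added to a clone (for example by
layerdiffuse or Fooocus inpaint) disappear together with the clone instead of
lingering in a shared LoraLoader.

class MiniPatcher:
    def __init__(self):
        self.lora_patches = {}

    def clone(self):
        # Shallow-copy the patch registry, mirroring ModelPatcher.clone() above.
        n = MiniPatcher()
        n.lora_patches = self.lora_patches.copy()
        return n

base = MiniPatcher()
work = base.clone()
work.lora_patches[('inpaint.patch', 1.0, 1.0, False)] = {'some.key': []}
assert base.lora_patches == {}  # dropping `work` drops its patches with it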
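A minimal sketch (again toy code) of the refresh logic in backend/patcher/lora.py:
the requested LoRA set is identified by stringifying its
(filename, strength_patch, strength_model, online_mode) identifier tuples, so a
repeated identical request becomes a no-op instead of re-patching every weight.
The boolean return value here is only for illustration; the real refresh()
returns nothing.

class MiniLoraLoader:
    def __init__(self):
        self.loaded_hash = str([])  # matches the initial empty patch set

    def refresh(self, lora_patches):
        hashes = str(list(lora_patches.keys()))
        if hashes == self.loaded_hash:
            return False  # same files, strengths and mode: skip re-patching
        # ... merge and apply the patches here ...
        self.loaded_hash = hashes
        return True

loader = MiniLoraLoader()
request = {('style.safetensors', 0.8, 1.0, False): {}}
assert loader.refresh(request) is True   # first request patches the model
assert loader.refresh(request) is False  # repeated identical request is a no-op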
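A runnable sketch of the dtype precomputation behind item 3, using the
dtype_to_element_size() helper added to backend/utils.py: online-LoRA memory is
estimated from the computation dtype's element size, and the weights are cast
while being moved, since nested_move_to_device() now forwards **kwargs to
Tensor.to(). The shape and device below are arbitrary examples.

import torch

def dtype_to_element_size(dtype):
    # Same trick as in backend/utils.py: ask an empty tensor of that dtype
    # for its per-element byte size.
    if isinstance(dtype, torch.dtype):
        return torch.tensor([], dtype=dtype).element_size()
    raise ValueError(f'Invalid dtype: {dtype}')

w = torch.randn(320, 320, dtype=torch.float32)  # a stand-in LoRA weight
computation_dtype = torch.float16

# 320 * 320 elements * 2 bytes = 204800 bytes once cast to fp16.
print(w.nelement() * dtype_to_element_size(computation_dtype))

# Move and cast in one step; in the patch this happens via
# nested_move_to_device(..., device=..., dtype=...).
w = w.to(device='cpu', dtype=computation_dtype)  # device='cuda' in practice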