mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-14 17:37:25 +00:00
Maintain patching-related code
1. Fix several problems where layerdiffuse patches were not unloaded. 2. Fix several problems related to Fooocus inpaint. 3. Slightly speed up on-the-fly LoRAs by converting them to the computation dtype ahead of time.
This commit is contained in:
@@ -513,7 +513,7 @@ class LoadedModel:
|
||||
|
||||
bake_gguf_model(self.real_model)
|
||||
|
||||
self.model.lora_loader.refresh(offload_device=self.model.offload_device)
|
||||
self.model.refresh_loras()
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_hijack:
|
||||
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
||||
|
||||
@@ -121,8 +121,10 @@ current_bnb_dtype = None
|
||||
|
||||
class ForgeOperations:
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, in_features, out_features, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
|
||||
@@ -52,6 +52,7 @@ class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.lora_patches = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.model_options = {"transformer_options": {}}
|
||||
@@ -77,6 +78,7 @@ class ModelPatcher:
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
n.lora_patches = self.lora_patches.copy()
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
return n
|
||||
@@ -86,6 +88,44 @@ class ModelPatcher:
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_patches(self, *, filename, patches, strength_patch=1.0, strength_model=1.0, online_mode=False):
    """Register one set of weight patches on this patcher.

    Args:
        filename: Name of the patch source; part of the identifier under
            which the patches are stored.
        patches: Mapping from patch key to patch payload. A key is either
            a parameter name (``str``) or a sequence
            ``(name, offset[, function])``.
        strength_patch: Strength recorded with every patch entry.
        strength_model: Model strength recorded with every patch entry.
        online_mode: Flag stored in the identifier; ``has_online_lora``
            reads it back from there.

    Returns:
        The set of keys from ``patches`` whose parameter name exists in
        ``self.model.named_parameters()``. Keys that do not match are
        silently ignored.
    """
    lora_identifier = (filename, strength_patch, strength_model, online_mode)
    model_keys = {name for name, _ in self.model.named_parameters()}

    this_patches = {}
    matched = set()

    for k in patches:
        offset = None
        function = None

        if isinstance(k, str):
            key = k
        else:
            # Tuple-style key: (parameter name, offset[, function]).
            key = k[0]
            offset = k[1]
            if len(k) > 2:
                function = k[2]

        if key in model_keys:
            matched.add(k)
            this_patches.setdefault(key, []).append(
                [strength_patch, patches[k], strength_model, offset, function]
            )

    # Assigning under the identifier replaces any earlier entry that was
    # registered with the same (filename, strengths, online_mode) tuple.
    self.lora_patches[lora_identifier] = this_patches
    return matched
|
||||
|
||||
def has_online_lora(self):
    """Return True if any registered patch set was added with ``online_mode`` set.

    The keys of ``self.lora_patches`` are
    ``(filename, strength_patch, strength_model, online_mode)`` tuples; only
    the last element is inspected, so the patch payloads are never touched.
    """
    return any(online_mode for (_, _, _, online_mode) in self.lora_patches)
|
||||
|
||||
def refresh_loras(self):
    """Ask the attached ``lora_loader`` to refresh from this patcher's patches.

    Passes ``self.lora_patches`` and ``self.offload_device`` to
    ``self.lora_loader.refresh``. Returns ``None``.
    """
    self.lora_loader.refresh(lora_patches=self.lora_patches, offload_device=self.offload_device)
|
||||
|
||||
def memory_required(self, input_shape):
    """Return the wrapped model's memory requirement for ``input_shape``.

    Pure delegation to ``self.model.memory_required``; the unit of the
    result is whatever the wrapped model reports.
    """
    return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
|
||||
@@ -25,3 +25,6 @@ class CLIP:
|
||||
n.cond_stage_model = self.cond_stage_model
|
||||
n.tokenizer = self.tokenizer
|
||||
return n
|
||||
|
||||
def add_patches(self, *arg, **kwargs):
    """Forward all arguments to ``self.patcher.add_patches`` and return its result."""
    return self.patcher.add_patches(*arg, **kwargs)
|
||||
|
||||
@@ -286,55 +286,24 @@ from backend import operations
|
||||
class LoraLoader:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.online_backup = []
|
||||
self.dirty = False
|
||||
self.online_mode = False
|
||||
|
||||
def clear_patches(self):
|
||||
self.patches.clear()
|
||||
self.dirty = True
|
||||
return
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
p = set()
|
||||
model_sd = self.model.state_dict()
|
||||
|
||||
for k in patches:
|
||||
offset = None
|
||||
function = None
|
||||
|
||||
if isinstance(k, str):
|
||||
key = k
|
||||
else:
|
||||
offset = k[1]
|
||||
key = k[0]
|
||||
if len(k) > 2:
|
||||
function = k[2]
|
||||
|
||||
if key in model_sd:
|
||||
p.add(k)
|
||||
current_patches = self.patches.get(key, [])
|
||||
current_patches.append([strength_patch, patches[k], strength_model, offset, function])
|
||||
self.patches[key] = current_patches
|
||||
|
||||
self.dirty = True
|
||||
|
||||
self.online_mode = dynamic_args.get('online_lora', False)
|
||||
|
||||
if hasattr(self.model, 'storage_dtype'):
|
||||
if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
self.online_mode = False
|
||||
|
||||
return list(p)
|
||||
self.loaded_hash = str([])
|
||||
|
||||
@torch.inference_mode()
|
||||
def refresh(self, offload_device=torch.device('cpu')):
|
||||
if not self.dirty:
|
||||
def refresh(self, lora_patches, offload_device=torch.device('cpu')):
|
||||
hashes = str(list(lora_patches.keys()))
|
||||
|
||||
if hashes == self.loaded_hash:
|
||||
return
|
||||
|
||||
self.dirty = False
|
||||
# Merge Patches
|
||||
|
||||
all_patches = {}
|
||||
|
||||
for (_, _, _, online_mode), patches in lora_patches.items():
|
||||
for key, current_patches in patches.items():
|
||||
all_patches[(key, online_mode)] = all_patches.get((key, online_mode), []) + current_patches
|
||||
|
||||
# Initialize
|
||||
|
||||
@@ -362,14 +331,14 @@ class LoraLoader:
|
||||
|
||||
# Patch
|
||||
|
||||
for key, current_patches in self.patches.items():
|
||||
for (key, online_mode), current_patches in all_patches.items():
|
||||
try:
|
||||
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
|
||||
assert isinstance(weight, torch.nn.Parameter)
|
||||
except:
|
||||
raise ValueError(f"Wrong LoRA Key: {key}")
|
||||
|
||||
if self.online_mode:
|
||||
if online_mode:
|
||||
if not hasattr(parent_layer, 'forge_online_loras'):
|
||||
parent_layer.forge_online_loras = {}
|
||||
|
||||
@@ -418,11 +387,5 @@ class LoraLoader:
|
||||
# End
|
||||
|
||||
set_parameter_devices(self.model, parameter_devices=parameter_devices)
|
||||
|
||||
if len(self.patches) > 0:
|
||||
if self.online_mode:
|
||||
print(f'Patched LoRAs on-the-fly; ', end='')
|
||||
else:
|
||||
print(f'Patched LoRAs by precomputing model weights; ', end='')
|
||||
|
||||
self.loaded_hash = hashes
|
||||
return
|
||||
|
||||
@@ -176,7 +176,7 @@ class UnetPatcher(ModelPatcher):
|
||||
self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
|
||||
return
|
||||
|
||||
def load_frozen_patcher(self, state_dict, strength):
|
||||
def load_frozen_patcher(self, filename, state_dict, strength):
|
||||
patch_dict = {}
|
||||
for k, w in state_dict.items():
|
||||
model_key, patch_type, weight_index = k.split('::')
|
||||
@@ -191,6 +191,5 @@ class UnetPatcher(ModelPatcher):
|
||||
for patch_type, weight_list in v.items():
|
||||
patch_flat[model_key] = (patch_type, weight_list)
|
||||
|
||||
self.lora_loader.clear_patches()
|
||||
self.lora_loader.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
||||
self.add_patches(filename=filename, patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
||||
return
|
||||
|
||||
@@ -376,16 +376,16 @@ def sampling_prepare(unet, x):
|
||||
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
|
||||
additional_model_patchers += unet.controlnet_linked_list.get_models()
|
||||
|
||||
if unet.lora_loader.online_mode:
|
||||
lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
|
||||
if unet.has_online_lora():
|
||||
lora_memory = utils.nested_compute_size(unet.lora_patches, element_size=utils.dtype_to_element_size(unet.model.computation_dtype))
|
||||
additional_inference_memory += lora_memory
|
||||
|
||||
memory_management.load_models_gpu(
|
||||
models=[unet] + additional_model_patchers,
|
||||
memory_required=unet_inference_memory + additional_inference_memory)
|
||||
|
||||
if unet.lora_loader.online_mode:
|
||||
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
|
||||
if unet.has_online_lora():
|
||||
utils.nested_move_to_device(unet.lora_patches, device=unet.current_device, dtype=unet.model.computation_dtype)
|
||||
|
||||
real_model = unet.model
|
||||
|
||||
@@ -398,8 +398,8 @@ def sampling_prepare(unet, x):
|
||||
|
||||
|
||||
def sampling_cleanup(unet):
|
||||
if unet.lora_loader.online_mode:
|
||||
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.offload_device)
|
||||
if unet.has_online_lora():
|
||||
utils.nested_move_to_device(unet.lora_patches, device=unet.offload_device)
|
||||
for cnet in unet.list_controlnets():
|
||||
cnet.cleanup()
|
||||
cleanup_cache()
|
||||
|
||||
@@ -111,32 +111,39 @@ def fp16_fix(x):
|
||||
return x
|
||||
|
||||
|
||||
def nested_compute_size(obj):
|
||||
def dtype_to_element_size(dtype):
    """Return the size in bytes of one element of ``dtype``.

    Args:
        dtype: A ``torch.dtype``.

    Raises:
        ValueError: If ``dtype`` is not a ``torch.dtype`` instance.
    """
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Invalid dtype: {dtype}")
    # An empty tensor of the requested dtype reports the per-element size.
    return torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
|
||||
def nested_compute_size(obj, element_size=None):
    """Sum the byte size of every tensor found in a nested container.

    Recurses through dicts (values only), lists and tuples. Anything that is
    neither a container of those kinds nor a ``torch.Tensor`` contributes 0.

    Args:
        obj: A tensor, or an arbitrarily nested dict/list/tuple structure.
        element_size: Bytes to charge per tensor element, e.g. the element
            size of the dtype the tensors will be cast to. When ``None``,
            each tensor's own ``element_size()`` is used, so one-argument
            callers get the size of the data as currently stored.

    Returns:
        Total size in bytes as an int.
    """
    if isinstance(obj, torch.Tensor):
        per_element = obj.element_size() if element_size is None else element_size
        return obj.nelement() * per_element
    if isinstance(obj, dict):
        return sum(nested_compute_size(value, element_size) for value in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(nested_compute_size(item, element_size) for item in obj)
    return 0
|
||||
|
||||
|
||||
def nested_move_to_device(obj, *args, **kwargs):
    """Apply ``Tensor.to(*args, **kwargs)`` to every tensor in a nested container.

    Dicts and lists are updated in place and returned as the same object.
    Tuples are rebuilt, because they cannot be mutated. Objects of any other
    type are returned unchanged.

    Args:
        obj: A tensor, or an arbitrarily nested dict/list/tuple structure.
        *args: Positional arguments forwarded to ``Tensor.to``; this keeps
            the ``nested_move_to_device(obj, device)`` call style working.
        **kwargs: Keyword arguments forwarded to ``Tensor.to``
            (e.g. ``device=...``, ``dtype=...``).

    Returns:
        The converted tensor, the same dict/list after in-place update, a new
        tuple, or ``obj`` itself for any other type.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(*args, **kwargs)

    if isinstance(obj, dict):
        # Only values are reassigned, so iterating the keys stays safe.
        for key in obj:
            obj[key] = nested_move_to_device(obj[key], *args, **kwargs)
    elif isinstance(obj, list):
        for index, item in enumerate(obj):
            obj[index] = nested_move_to_device(item, *args, **kwargs)
    elif isinstance(obj, tuple):
        return tuple(nested_move_to_device(item, *args, **kwargs) for item in obj)

    return obj
|
||||
|
||||
|
||||
|
||||
@@ -55,13 +55,14 @@ class FooocusInpaintPatcher(ControlModelPatcher):
|
||||
def try_build_from_state_dict(state_dict, ckpt_path):
|
||||
if 'diffusion_model.time_embed.0.weight' in state_dict:
|
||||
if len(state_dict['diffusion_model.time_embed.0.weight']) == 3:
|
||||
return FooocusInpaintPatcher(state_dict)
|
||||
return FooocusInpaintPatcher(state_dict, ckpt_path)
|
||||
|
||||
return None
|
||||
|
||||
def __init__(self, state_dict):
|
||||
def __init__(self, state_dict, filename):
|
||||
super().__init__()
|
||||
self.state_dict = state_dict
|
||||
self.filename = filename
|
||||
self.inpaint_head = InpaintHead().to(device=torch.device('cpu'), dtype=torch.float32)
|
||||
self.inpaint_head.load_state_dict(load_torch_file(os.path.join(os.path.dirname(__file__), 'fooocus_inpaint_head')))
|
||||
|
||||
@@ -95,8 +96,7 @@ class FooocusInpaintPatcher(ControlModelPatcher):
|
||||
lora_keys.update({x: x for x in unet.model.state_dict().keys()})
|
||||
loaded_lora = load_fooocus_patch(self.state_dict, lora_keys)
|
||||
|
||||
unet.lora_loader.clear_patches() # TODO
|
||||
patched = unet.lora_loader.add_patches(loaded_lora, 1.0)
|
||||
patched = unet.add_patches(filename=self.filename, patches=loaded_lora)
|
||||
|
||||
not_patched_count = sum(1 for x in loaded_lora if x not in patched)
|
||||
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
from __future__ import annotations
|
||||
import gradio as gr
|
||||
import logging
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
import functools
|
||||
import network
|
||||
|
||||
import torch
|
||||
from typing import Union
|
||||
import network
|
||||
import functools
|
||||
|
||||
from backend.args import dynamic_args
|
||||
from modules import shared, sd_models, errors, scripts
|
||||
from backend.utils import load_torch_file
|
||||
from backend.patcher.lora import model_lora_keys_clip, model_lora_keys_unet, load_lora
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default', online_mode=False):
|
||||
model_flag = type(model.model).__name__ if model is not None else 'default'
|
||||
|
||||
unet_keys = model_lora_keys_unet(model.model) if model is not None else {}
|
||||
@@ -32,23 +29,28 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
|
||||
if len(lora_unmatch) > 0:
|
||||
print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
|
||||
|
||||
if model is not None and len(lora_unet) > 0:
|
||||
loaded_keys = model.lora_loader.add_patches(lora_unet, strength_model)
|
||||
new_model = model.clone() if model is not None else None
|
||||
new_clip = clip.clone() if clip is not None else None
|
||||
|
||||
if new_model is not None and len(lora_unet) > 0:
|
||||
loaded_keys = new_model.add_patches(filename=filename, patches=lora_unet, strength_patch=strength_model, online_mode=online_mode)
|
||||
skipped_keys = [item for item in lora_unet if item not in loaded_keys]
|
||||
if len(skipped_keys) > 12:
|
||||
print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
|
||||
else:
|
||||
print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
|
||||
print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
|
||||
model = new_model
|
||||
|
||||
if clip is not None and len(lora_clip) > 0:
|
||||
loaded_keys = clip.patcher.lora_loader.add_patches(lora_clip, strength_clip)
|
||||
if new_clip is not None and len(lora_clip) > 0:
|
||||
loaded_keys = new_clip.add_patches(filename=filename, patches=lora_clip, strength_patch=strength_clip, online_mode=online_mode)
|
||||
skipped_keys = [item for item in lora_clip if item not in loaded_keys]
|
||||
if len(skipped_keys) > 12:
|
||||
print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
|
||||
else:
|
||||
print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
|
||||
print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
|
||||
clip = new_clip
|
||||
|
||||
return
|
||||
return model, clip
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=5)
|
||||
@@ -97,9 +99,14 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
network_on_disk.read_hash()
|
||||
loaded_networks.append(net)
|
||||
|
||||
online_mode = dynamic_args.get('online_lora', False)
|
||||
|
||||
if current_sd.forge_objects.unet.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
online_mode = False
|
||||
|
||||
compiled_lora_targets = []
|
||||
for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
|
||||
compiled_lora_targets.append([a.filename, b, c])
|
||||
compiled_lora_targets.append([a.filename, b, c, online_mode])
|
||||
|
||||
compiled_lora_targets_hash = str(compiled_lora_targets)
|
||||
|
||||
@@ -107,15 +114,14 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
return
|
||||
|
||||
current_sd.current_lora_hash = compiled_lora_targets_hash
|
||||
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet.clone()
|
||||
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip.clone()
|
||||
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
|
||||
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
|
||||
|
||||
current_sd.forge_objects.unet.lora_loader.clear_patches()
|
||||
current_sd.forge_objects.clip.patcher.lora_loader.clear_patches()
|
||||
|
||||
for filename, strength_model, strength_clip in compiled_lora_targets:
|
||||
for filename, strength_model, strength_clip, online_mode in compiled_lora_targets:
|
||||
lora_sd = load_lora_state_dict(filename)
|
||||
load_lora_for_models(current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, filename=filename)
|
||||
current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
|
||||
current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
|
||||
filename=filename, online_mode=online_mode)
|
||||
|
||||
current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
|
||||
return
|
||||
|
||||
@@ -802,8 +802,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
memory_management.unload_all_models()
|
||||
|
||||
if need_global_unload:
|
||||
p.sd_model.current_lora_hash = str([])
|
||||
p.sd_model.forge_objects.unet.lora_loader.dirty = True
|
||||
p.clear_prompt_cache()
|
||||
|
||||
need_global_unload = False
|
||||
|
||||
Reference in New Issue
Block a user