Maintain patching related

1. fix several problems related to layerdiffuse not unloaded
2. fix several problems related to Fooocus inpaint
3. Slightly speed up on-the-fly LoRAs by precomputing them to computation dtype
This commit is contained in:
layerdiffusion
2024-08-30 15:14:32 -07:00
parent a8483a3f79
commit d1d0ec46aa
11 changed files with 119 additions and 101 deletions

View File

@@ -513,7 +513,7 @@ class LoadedModel:
bake_gguf_model(self.real_model)
self.model.lora_loader.refresh(offload_device=self.model.offload_device)
self.model.refresh_loras()
if is_intel_xpu() and not args.disable_ipex_hijack:
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)

View File

@@ -121,8 +121,10 @@ current_bnb_dtype = None
class ForgeOperations:
class Linear(torch.nn.Module):
def __init__(self, *args, **kwargs):
def __init__(self, in_features, out_features, *args, **kwargs):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
self.weight = None
self.bias = None

View File

@@ -52,6 +52,7 @@ class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
self.size = size
self.model = model
self.lora_patches = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options": {}}
@@ -77,6 +78,7 @@ class ModelPatcher:
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n.lora_patches = self.lora_patches.copy()
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
return n
@@ -86,6 +88,44 @@ class ModelPatcher:
return True
return False
def add_patches(self, *, filename, patches, strength_patch=1.0, strength_model=1.0, online_mode=False):
lora_identifier = (filename, strength_patch, strength_model, online_mode)
this_patches = {}
p = set()
model_keys = set(k for k, _ in self.model.named_parameters())
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]
if key in model_keys:
p.add(k)
current_patches = this_patches.get(key, [])
current_patches.append([strength_patch, patches[k], strength_model, offset, function])
this_patches[key] = current_patches
self.lora_patches[lora_identifier] = this_patches
return p
def has_online_lora(self):
for (filename, strength_patch, strength_model, online_mode), this_patches in self.lora_patches.items():
if online_mode:
return True
return False
def refresh_loras(self):
self.lora_loader.refresh(lora_patches=self.lora_patches, offload_device=self.offload_device)
return
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)

View File

@@ -25,3 +25,6 @@ class CLIP:
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
return n
def add_patches(self, *arg, **kwargs):
return self.patcher.add_patches(*arg, **kwargs)

View File

@@ -286,55 +286,24 @@ from backend import operations
class LoraLoader:
def __init__(self, model):
self.model = model
self.patches = {}
self.backup = {}
self.online_backup = []
self.dirty = False
self.online_mode = False
def clear_patches(self):
self.patches.clear()
self.dirty = True
return
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = set()
model_sd = self.model.state_dict()
for k in patches:
offset = None
function = None
if isinstance(k, str):
key = k
else:
offset = k[1]
key = k[0]
if len(k) > 2:
function = k[2]
if key in model_sd:
p.add(k)
current_patches = self.patches.get(key, [])
current_patches.append([strength_patch, patches[k], strength_model, offset, function])
self.patches[key] = current_patches
self.dirty = True
self.online_mode = dynamic_args.get('online_lora', False)
if hasattr(self.model, 'storage_dtype'):
if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
self.online_mode = False
return list(p)
self.loaded_hash = str([])
@torch.inference_mode()
def refresh(self, offload_device=torch.device('cpu')):
if not self.dirty:
def refresh(self, lora_patches, offload_device=torch.device('cpu')):
hashes = str(list(lora_patches.keys()))
if hashes == self.loaded_hash:
return
self.dirty = False
# Merge Patches
all_patches = {}
for (_, _, _, online_mode), patches in lora_patches.items():
for key, current_patches in patches.items():
all_patches[(key, online_mode)] = all_patches.get((key, online_mode), []) + current_patches
# Initialize
@@ -362,14 +331,14 @@ class LoraLoader:
# Patch
for key, current_patches in self.patches.items():
for (key, online_mode), current_patches in all_patches.items():
try:
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
assert isinstance(weight, torch.nn.Parameter)
except:
raise ValueError(f"Wrong LoRA Key: {key}")
if self.online_mode:
if online_mode:
if not hasattr(parent_layer, 'forge_online_loras'):
parent_layer.forge_online_loras = {}
@@ -418,11 +387,5 @@ class LoraLoader:
# End
set_parameter_devices(self.model, parameter_devices=parameter_devices)
if len(self.patches) > 0:
if self.online_mode:
print(f'Patched LoRAs on-the-fly; ', end='')
else:
print(f'Patched LoRAs by precomputing model weights; ', end='')
self.loaded_hash = hashes
return

View File

@@ -176,7 +176,7 @@ class UnetPatcher(ModelPatcher):
self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
return
def load_frozen_patcher(self, state_dict, strength):
def load_frozen_patcher(self, filename, state_dict, strength):
patch_dict = {}
for k, w in state_dict.items():
model_key, patch_type, weight_index = k.split('::')
@@ -191,6 +191,5 @@ class UnetPatcher(ModelPatcher):
for patch_type, weight_list in v.items():
patch_flat[model_key] = (patch_type, weight_list)
self.lora_loader.clear_patches()
self.lora_loader.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
self.add_patches(filename=filename, patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
return

View File

@@ -376,16 +376,16 @@ def sampling_prepare(unet, x):
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
additional_model_patchers += unet.controlnet_linked_list.get_models()
if unet.lora_loader.online_mode:
lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
if unet.has_online_lora():
lora_memory = utils.nested_compute_size(unet.lora_patches, element_size=utils.dtype_to_element_size(unet.model.computation_dtype))
additional_inference_memory += lora_memory
memory_management.load_models_gpu(
models=[unet] + additional_model_patchers,
memory_required=unet_inference_memory + additional_inference_memory)
if unet.lora_loader.online_mode:
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
if unet.has_online_lora():
utils.nested_move_to_device(unet.lora_patches, device=unet.current_device, dtype=unet.model.computation_dtype)
real_model = unet.model
@@ -398,8 +398,8 @@ def sampling_prepare(unet, x):
def sampling_cleanup(unet):
if unet.lora_loader.online_mode:
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.offload_device)
if unet.has_online_lora():
utils.nested_move_to_device(unet.lora_patches, device=unet.offload_device)
for cnet in unet.list_controlnets():
cnet.cleanup()
cleanup_cache()

View File

@@ -111,32 +111,39 @@ def fp16_fix(x):
return x
def nested_compute_size(obj):
def dtype_to_element_size(dtype):
if isinstance(dtype, torch.dtype):
return torch.tensor([], dtype=dtype).element_size()
else:
raise ValueError(f"Invalid dtype: {dtype}")
def nested_compute_size(obj, element_size):
module_mem = 0
if isinstance(obj, dict):
for key in obj:
module_mem += nested_compute_size(obj[key])
module_mem += nested_compute_size(obj[key], element_size)
elif isinstance(obj, list) or isinstance(obj, tuple):
for i in range(len(obj)):
module_mem += nested_compute_size(obj[i])
module_mem += nested_compute_size(obj[i], element_size)
elif isinstance(obj, torch.Tensor):
module_mem += obj.nelement() * obj.element_size()
module_mem += obj.nelement() * element_size
return module_mem
def nested_move_to_device(obj, device):
def nested_move_to_device(obj, **kwargs):
if isinstance(obj, dict):
for key in obj:
obj[key] = nested_move_to_device(obj[key], device)
obj[key] = nested_move_to_device(obj[key], **kwargs)
elif isinstance(obj, list):
for i in range(len(obj)):
obj[i] = nested_move_to_device(obj[i], device)
obj[i] = nested_move_to_device(obj[i], **kwargs)
elif isinstance(obj, tuple):
obj = tuple(nested_move_to_device(i, device) for i in obj)
obj = tuple(nested_move_to_device(i, **kwargs) for i in obj)
elif isinstance(obj, torch.Tensor):
return obj.to(device)
return obj.to(**kwargs)
return obj

View File

@@ -55,13 +55,14 @@ class FooocusInpaintPatcher(ControlModelPatcher):
def try_build_from_state_dict(state_dict, ckpt_path):
if 'diffusion_model.time_embed.0.weight' in state_dict:
if len(state_dict['diffusion_model.time_embed.0.weight']) == 3:
return FooocusInpaintPatcher(state_dict)
return FooocusInpaintPatcher(state_dict, ckpt_path)
return None
def __init__(self, state_dict):
def __init__(self, state_dict, filename):
super().__init__()
self.state_dict = state_dict
self.filename = filename
self.inpaint_head = InpaintHead().to(device=torch.device('cpu'), dtype=torch.float32)
self.inpaint_head.load_state_dict(load_torch_file(os.path.join(os.path.dirname(__file__), 'fooocus_inpaint_head')))
@@ -95,8 +96,7 @@ class FooocusInpaintPatcher(ControlModelPatcher):
lora_keys.update({x: x for x in unet.model.state_dict().keys()})
loaded_lora = load_fooocus_patch(self.state_dict, lora_keys)
unet.lora_loader.clear_patches() # TODO
patched = unet.lora_loader.add_patches(loaded_lora, 1.0)
patched = unet.add_patches(filename=self.filename, patches=loaded_lora)
not_patched_count = sum(1 for x in loaded_lora if x not in patched)

View File

@@ -1,21 +1,18 @@
from __future__ import annotations
import gradio as gr
import logging
import os
import re
import functools
import network
import torch
from typing import Union
import network
import functools
from backend.args import dynamic_args
from modules import shared, sd_models, errors, scripts
from backend.utils import load_torch_file
from backend.patcher.lora import model_lora_keys_clip, model_lora_keys_unet, load_lora
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default', online_mode=False):
model_flag = type(model.model).__name__ if model is not None else 'default'
unet_keys = model_lora_keys_unet(model.model) if model is not None else {}
@@ -32,23 +29,28 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
if len(lora_unmatch) > 0:
print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
if model is not None and len(lora_unet) > 0:
loaded_keys = model.lora_loader.add_patches(lora_unet, strength_model)
new_model = model.clone() if model is not None else None
new_clip = clip.clone() if clip is not None else None
if new_model is not None and len(lora_unet) > 0:
loaded_keys = new_model.add_patches(filename=filename, patches=lora_unet, strength_patch=strength_model, online_mode=online_mode)
skipped_keys = [item for item in lora_unet if item not in loaded_keys]
if len(skipped_keys) > 12:
print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
else:
print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
model = new_model
if clip is not None and len(lora_clip) > 0:
loaded_keys = clip.patcher.lora_loader.add_patches(lora_clip, strength_clip)
if new_clip is not None and len(lora_clip) > 0:
loaded_keys = new_clip.add_patches(filename=filename, patches=lora_clip, strength_patch=strength_clip, online_mode=online_mode)
skipped_keys = [item for item in lora_clip if item not in loaded_keys]
if len(skipped_keys) > 12:
print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
else:
print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
clip = new_clip
return
return model, clip
@functools.lru_cache(maxsize=5)
@@ -97,9 +99,14 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
network_on_disk.read_hash()
loaded_networks.append(net)
online_mode = dynamic_args.get('online_lora', False)
if current_sd.forge_objects.unet.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
online_mode = False
compiled_lora_targets = []
for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
compiled_lora_targets.append([a.filename, b, c])
compiled_lora_targets.append([a.filename, b, c, online_mode])
compiled_lora_targets_hash = str(compiled_lora_targets)
@@ -107,15 +114,14 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
return
current_sd.current_lora_hash = compiled_lora_targets_hash
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet.clone()
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip.clone()
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
current_sd.forge_objects.unet.lora_loader.clear_patches()
current_sd.forge_objects.clip.patcher.lora_loader.clear_patches()
for filename, strength_model, strength_clip in compiled_lora_targets:
for filename, strength_model, strength_clip, online_mode in compiled_lora_targets:
lora_sd = load_lora_state_dict(filename)
load_lora_for_models(current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, filename=filename)
current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
filename=filename, online_mode=online_mode)
current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
return

View File

@@ -802,8 +802,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
memory_management.unload_all_models()
if need_global_unload:
p.sd_model.current_lora_hash = str([])
p.sd_model.forge_objects.unet.lora_loader.dirty = True
p.clear_prompt_cache()
need_global_unload = False