mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 19:21:21 +00:00
Maintain patching related
1. fix several problems related to layerdiffuse not unloaded 2. fix several problems related to Fooocus inpaint 3. Slightly speed up on-the-fly LoRAs by precomputing them to computation dtype
This commit is contained in:
@@ -513,7 +513,7 @@ class LoadedModel:
|
|||||||
|
|
||||||
bake_gguf_model(self.real_model)
|
bake_gguf_model(self.real_model)
|
||||||
|
|
||||||
self.model.lora_loader.refresh(offload_device=self.model.offload_device)
|
self.model.refresh_loras()
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_hijack:
|
if is_intel_xpu() and not args.disable_ipex_hijack:
|
||||||
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
||||||
|
|||||||
@@ -121,8 +121,10 @@ current_bnb_dtype = None
|
|||||||
|
|
||||||
class ForgeOperations:
|
class ForgeOperations:
|
||||||
class Linear(torch.nn.Module):
|
class Linear(torch.nn.Module):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, in_features, out_features, *args, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
|
self.dummy = torch.nn.Parameter(torch.empty(1, device=current_device, dtype=current_dtype))
|
||||||
self.weight = None
|
self.weight = None
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class ModelPatcher:
|
|||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.lora_patches = {}
|
||||||
self.object_patches = {}
|
self.object_patches = {}
|
||||||
self.object_patches_backup = {}
|
self.object_patches_backup = {}
|
||||||
self.model_options = {"transformer_options": {}}
|
self.model_options = {"transformer_options": {}}
|
||||||
@@ -77,6 +78,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||||
|
n.lora_patches = self.lora_patches.copy()
|
||||||
n.object_patches = self.object_patches.copy()
|
n.object_patches = self.object_patches.copy()
|
||||||
n.model_options = copy.deepcopy(self.model_options)
|
n.model_options = copy.deepcopy(self.model_options)
|
||||||
return n
|
return n
|
||||||
@@ -86,6 +88,44 @@ class ModelPatcher:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def add_patches(self, *, filename, patches, strength_patch=1.0, strength_model=1.0, online_mode=False):
|
||||||
|
lora_identifier = (filename, strength_patch, strength_model, online_mode)
|
||||||
|
this_patches = {}
|
||||||
|
|
||||||
|
p = set()
|
||||||
|
model_keys = set(k for k, _ in self.model.named_parameters())
|
||||||
|
|
||||||
|
for k in patches:
|
||||||
|
offset = None
|
||||||
|
function = None
|
||||||
|
|
||||||
|
if isinstance(k, str):
|
||||||
|
key = k
|
||||||
|
else:
|
||||||
|
offset = k[1]
|
||||||
|
key = k[0]
|
||||||
|
if len(k) > 2:
|
||||||
|
function = k[2]
|
||||||
|
|
||||||
|
if key in model_keys:
|
||||||
|
p.add(k)
|
||||||
|
current_patches = this_patches.get(key, [])
|
||||||
|
current_patches.append([strength_patch, patches[k], strength_model, offset, function])
|
||||||
|
this_patches[key] = current_patches
|
||||||
|
|
||||||
|
self.lora_patches[lora_identifier] = this_patches
|
||||||
|
return p
|
||||||
|
|
||||||
|
def has_online_lora(self):
|
||||||
|
for (filename, strength_patch, strength_model, online_mode), this_patches in self.lora_patches.items():
|
||||||
|
if online_mode:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def refresh_loras(self):
|
||||||
|
self.lora_loader.refresh(lora_patches=self.lora_patches, offload_device=self.offload_device)
|
||||||
|
return
|
||||||
|
|
||||||
def memory_required(self, input_shape):
|
def memory_required(self, input_shape):
|
||||||
return self.model.memory_required(input_shape=input_shape)
|
return self.model.memory_required(input_shape=input_shape)
|
||||||
|
|
||||||
|
|||||||
@@ -25,3 +25,6 @@ class CLIP:
|
|||||||
n.cond_stage_model = self.cond_stage_model
|
n.cond_stage_model = self.cond_stage_model
|
||||||
n.tokenizer = self.tokenizer
|
n.tokenizer = self.tokenizer
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def add_patches(self, *arg, **kwargs):
|
||||||
|
return self.patcher.add_patches(*arg, **kwargs)
|
||||||
|
|||||||
@@ -286,55 +286,24 @@ from backend import operations
|
|||||||
class LoraLoader:
|
class LoraLoader:
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.patches = {}
|
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
self.online_backup = []
|
self.online_backup = []
|
||||||
self.dirty = False
|
self.loaded_hash = str([])
|
||||||
self.online_mode = False
|
|
||||||
|
|
||||||
def clear_patches(self):
|
|
||||||
self.patches.clear()
|
|
||||||
self.dirty = True
|
|
||||||
return
|
|
||||||
|
|
||||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
|
||||||
p = set()
|
|
||||||
model_sd = self.model.state_dict()
|
|
||||||
|
|
||||||
for k in patches:
|
|
||||||
offset = None
|
|
||||||
function = None
|
|
||||||
|
|
||||||
if isinstance(k, str):
|
|
||||||
key = k
|
|
||||||
else:
|
|
||||||
offset = k[1]
|
|
||||||
key = k[0]
|
|
||||||
if len(k) > 2:
|
|
||||||
function = k[2]
|
|
||||||
|
|
||||||
if key in model_sd:
|
|
||||||
p.add(k)
|
|
||||||
current_patches = self.patches.get(key, [])
|
|
||||||
current_patches.append([strength_patch, patches[k], strength_model, offset, function])
|
|
||||||
self.patches[key] = current_patches
|
|
||||||
|
|
||||||
self.dirty = True
|
|
||||||
|
|
||||||
self.online_mode = dynamic_args.get('online_lora', False)
|
|
||||||
|
|
||||||
if hasattr(self.model, 'storage_dtype'):
|
|
||||||
if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
|
||||||
self.online_mode = False
|
|
||||||
|
|
||||||
return list(p)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def refresh(self, offload_device=torch.device('cpu')):
|
def refresh(self, lora_patches, offload_device=torch.device('cpu')):
|
||||||
if not self.dirty:
|
hashes = str(list(lora_patches.keys()))
|
||||||
|
|
||||||
|
if hashes == self.loaded_hash:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.dirty = False
|
# Merge Patches
|
||||||
|
|
||||||
|
all_patches = {}
|
||||||
|
|
||||||
|
for (_, _, _, online_mode), patches in lora_patches.items():
|
||||||
|
for key, current_patches in patches.items():
|
||||||
|
all_patches[(key, online_mode)] = all_patches.get((key, online_mode), []) + current_patches
|
||||||
|
|
||||||
# Initialize
|
# Initialize
|
||||||
|
|
||||||
@@ -362,14 +331,14 @@ class LoraLoader:
|
|||||||
|
|
||||||
# Patch
|
# Patch
|
||||||
|
|
||||||
for key, current_patches in self.patches.items():
|
for (key, online_mode), current_patches in all_patches.items():
|
||||||
try:
|
try:
|
||||||
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
|
parent_layer, child_key, weight = utils.get_attr_with_parent(self.model, key)
|
||||||
assert isinstance(weight, torch.nn.Parameter)
|
assert isinstance(weight, torch.nn.Parameter)
|
||||||
except:
|
except:
|
||||||
raise ValueError(f"Wrong LoRA Key: {key}")
|
raise ValueError(f"Wrong LoRA Key: {key}")
|
||||||
|
|
||||||
if self.online_mode:
|
if online_mode:
|
||||||
if not hasattr(parent_layer, 'forge_online_loras'):
|
if not hasattr(parent_layer, 'forge_online_loras'):
|
||||||
parent_layer.forge_online_loras = {}
|
parent_layer.forge_online_loras = {}
|
||||||
|
|
||||||
@@ -418,11 +387,5 @@ class LoraLoader:
|
|||||||
# End
|
# End
|
||||||
|
|
||||||
set_parameter_devices(self.model, parameter_devices=parameter_devices)
|
set_parameter_devices(self.model, parameter_devices=parameter_devices)
|
||||||
|
self.loaded_hash = hashes
|
||||||
if len(self.patches) > 0:
|
|
||||||
if self.online_mode:
|
|
||||||
print(f'Patched LoRAs on-the-fly; ', end='')
|
|
||||||
else:
|
|
||||||
print(f'Patched LoRAs by precomputing model weights; ', end='')
|
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ class UnetPatcher(ModelPatcher):
|
|||||||
self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
|
self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
|
||||||
return
|
return
|
||||||
|
|
||||||
def load_frozen_patcher(self, state_dict, strength):
|
def load_frozen_patcher(self, filename, state_dict, strength):
|
||||||
patch_dict = {}
|
patch_dict = {}
|
||||||
for k, w in state_dict.items():
|
for k, w in state_dict.items():
|
||||||
model_key, patch_type, weight_index = k.split('::')
|
model_key, patch_type, weight_index = k.split('::')
|
||||||
@@ -191,6 +191,5 @@ class UnetPatcher(ModelPatcher):
|
|||||||
for patch_type, weight_list in v.items():
|
for patch_type, weight_list in v.items():
|
||||||
patch_flat[model_key] = (patch_type, weight_list)
|
patch_flat[model_key] = (patch_type, weight_list)
|
||||||
|
|
||||||
self.lora_loader.clear_patches()
|
self.add_patches(filename=filename, patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
||||||
self.lora_loader.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -376,16 +376,16 @@ def sampling_prepare(unet, x):
|
|||||||
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
|
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
|
||||||
additional_model_patchers += unet.controlnet_linked_list.get_models()
|
additional_model_patchers += unet.controlnet_linked_list.get_models()
|
||||||
|
|
||||||
if unet.lora_loader.online_mode:
|
if unet.has_online_lora():
|
||||||
lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
|
lora_memory = utils.nested_compute_size(unet.lora_patches, element_size=utils.dtype_to_element_size(unet.model.computation_dtype))
|
||||||
additional_inference_memory += lora_memory
|
additional_inference_memory += lora_memory
|
||||||
|
|
||||||
memory_management.load_models_gpu(
|
memory_management.load_models_gpu(
|
||||||
models=[unet] + additional_model_patchers,
|
models=[unet] + additional_model_patchers,
|
||||||
memory_required=unet_inference_memory + additional_inference_memory)
|
memory_required=unet_inference_memory + additional_inference_memory)
|
||||||
|
|
||||||
if unet.lora_loader.online_mode:
|
if unet.has_online_lora():
|
||||||
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
|
utils.nested_move_to_device(unet.lora_patches, device=unet.current_device, dtype=unet.model.computation_dtype)
|
||||||
|
|
||||||
real_model = unet.model
|
real_model = unet.model
|
||||||
|
|
||||||
@@ -398,8 +398,8 @@ def sampling_prepare(unet, x):
|
|||||||
|
|
||||||
|
|
||||||
def sampling_cleanup(unet):
|
def sampling_cleanup(unet):
|
||||||
if unet.lora_loader.online_mode:
|
if unet.has_online_lora():
|
||||||
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.offload_device)
|
utils.nested_move_to_device(unet.lora_patches, device=unet.offload_device)
|
||||||
for cnet in unet.list_controlnets():
|
for cnet in unet.list_controlnets():
|
||||||
cnet.cleanup()
|
cnet.cleanup()
|
||||||
cleanup_cache()
|
cleanup_cache()
|
||||||
|
|||||||
@@ -111,32 +111,39 @@ def fp16_fix(x):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def nested_compute_size(obj):
|
def dtype_to_element_size(dtype):
|
||||||
|
if isinstance(dtype, torch.dtype):
|
||||||
|
return torch.tensor([], dtype=dtype).element_size()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid dtype: {dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def nested_compute_size(obj, element_size):
|
||||||
module_mem = 0
|
module_mem = 0
|
||||||
|
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
for key in obj:
|
for key in obj:
|
||||||
module_mem += nested_compute_size(obj[key])
|
module_mem += nested_compute_size(obj[key], element_size)
|
||||||
elif isinstance(obj, list) or isinstance(obj, tuple):
|
elif isinstance(obj, list) or isinstance(obj, tuple):
|
||||||
for i in range(len(obj)):
|
for i in range(len(obj)):
|
||||||
module_mem += nested_compute_size(obj[i])
|
module_mem += nested_compute_size(obj[i], element_size)
|
||||||
elif isinstance(obj, torch.Tensor):
|
elif isinstance(obj, torch.Tensor):
|
||||||
module_mem += obj.nelement() * obj.element_size()
|
module_mem += obj.nelement() * element_size
|
||||||
|
|
||||||
return module_mem
|
return module_mem
|
||||||
|
|
||||||
|
|
||||||
def nested_move_to_device(obj, device):
|
def nested_move_to_device(obj, **kwargs):
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
for key in obj:
|
for key in obj:
|
||||||
obj[key] = nested_move_to_device(obj[key], device)
|
obj[key] = nested_move_to_device(obj[key], **kwargs)
|
||||||
elif isinstance(obj, list):
|
elif isinstance(obj, list):
|
||||||
for i in range(len(obj)):
|
for i in range(len(obj)):
|
||||||
obj[i] = nested_move_to_device(obj[i], device)
|
obj[i] = nested_move_to_device(obj[i], **kwargs)
|
||||||
elif isinstance(obj, tuple):
|
elif isinstance(obj, tuple):
|
||||||
obj = tuple(nested_move_to_device(i, device) for i in obj)
|
obj = tuple(nested_move_to_device(i, **kwargs) for i in obj)
|
||||||
elif isinstance(obj, torch.Tensor):
|
elif isinstance(obj, torch.Tensor):
|
||||||
return obj.to(device)
|
return obj.to(**kwargs)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -55,13 +55,14 @@ class FooocusInpaintPatcher(ControlModelPatcher):
|
|||||||
def try_build_from_state_dict(state_dict, ckpt_path):
|
def try_build_from_state_dict(state_dict, ckpt_path):
|
||||||
if 'diffusion_model.time_embed.0.weight' in state_dict:
|
if 'diffusion_model.time_embed.0.weight' in state_dict:
|
||||||
if len(state_dict['diffusion_model.time_embed.0.weight']) == 3:
|
if len(state_dict['diffusion_model.time_embed.0.weight']) == 3:
|
||||||
return FooocusInpaintPatcher(state_dict)
|
return FooocusInpaintPatcher(state_dict, ckpt_path)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __init__(self, state_dict):
|
def __init__(self, state_dict, filename):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.state_dict = state_dict
|
self.state_dict = state_dict
|
||||||
|
self.filename = filename
|
||||||
self.inpaint_head = InpaintHead().to(device=torch.device('cpu'), dtype=torch.float32)
|
self.inpaint_head = InpaintHead().to(device=torch.device('cpu'), dtype=torch.float32)
|
||||||
self.inpaint_head.load_state_dict(load_torch_file(os.path.join(os.path.dirname(__file__), 'fooocus_inpaint_head')))
|
self.inpaint_head.load_state_dict(load_torch_file(os.path.join(os.path.dirname(__file__), 'fooocus_inpaint_head')))
|
||||||
|
|
||||||
@@ -95,8 +96,7 @@ class FooocusInpaintPatcher(ControlModelPatcher):
|
|||||||
lora_keys.update({x: x for x in unet.model.state_dict().keys()})
|
lora_keys.update({x: x for x in unet.model.state_dict().keys()})
|
||||||
loaded_lora = load_fooocus_patch(self.state_dict, lora_keys)
|
loaded_lora = load_fooocus_patch(self.state_dict, lora_keys)
|
||||||
|
|
||||||
unet.lora_loader.clear_patches() # TODO
|
patched = unet.add_patches(filename=self.filename, patches=loaded_lora)
|
||||||
patched = unet.lora_loader.add_patches(loaded_lora, 1.0)
|
|
||||||
|
|
||||||
not_patched_count = sum(1 for x in loaded_lora if x not in patched)
|
not_patched_count = sum(1 for x in loaded_lora if x not in patched)
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import gradio as gr
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import functools
|
|
||||||
import network
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
import network
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from backend.args import dynamic_args
|
||||||
from modules import shared, sd_models, errors, scripts
|
from modules import shared, sd_models, errors, scripts
|
||||||
from backend.utils import load_torch_file
|
from backend.utils import load_torch_file
|
||||||
from backend.patcher.lora import model_lora_keys_clip, model_lora_keys_unet, load_lora
|
from backend.patcher.lora import model_lora_keys_clip, model_lora_keys_unet, load_lora
|
||||||
|
|
||||||
|
|
||||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
|
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default', online_mode=False):
|
||||||
model_flag = type(model.model).__name__ if model is not None else 'default'
|
model_flag = type(model.model).__name__ if model is not None else 'default'
|
||||||
|
|
||||||
unet_keys = model_lora_keys_unet(model.model) if model is not None else {}
|
unet_keys = model_lora_keys_unet(model.model) if model is not None else {}
|
||||||
@@ -32,23 +29,28 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
|
|||||||
if len(lora_unmatch) > 0:
|
if len(lora_unmatch) > 0:
|
||||||
print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
|
print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
|
||||||
|
|
||||||
if model is not None and len(lora_unet) > 0:
|
new_model = model.clone() if model is not None else None
|
||||||
loaded_keys = model.lora_loader.add_patches(lora_unet, strength_model)
|
new_clip = clip.clone() if clip is not None else None
|
||||||
|
|
||||||
|
if new_model is not None and len(lora_unet) > 0:
|
||||||
|
loaded_keys = new_model.add_patches(filename=filename, patches=lora_unet, strength_patch=strength_model, online_mode=online_mode)
|
||||||
skipped_keys = [item for item in lora_unet if item not in loaded_keys]
|
skipped_keys = [item for item in lora_unet if item not in loaded_keys]
|
||||||
if len(skipped_keys) > 12:
|
if len(skipped_keys) > 12:
|
||||||
print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
|
print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
|
||||||
else:
|
else:
|
||||||
print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
|
print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
|
||||||
|
model = new_model
|
||||||
|
|
||||||
if clip is not None and len(lora_clip) > 0:
|
if new_clip is not None and len(lora_clip) > 0:
|
||||||
loaded_keys = clip.patcher.lora_loader.add_patches(lora_clip, strength_clip)
|
loaded_keys = new_clip.add_patches(filename=filename, patches=lora_clip, strength_patch=strength_clip, online_mode=online_mode)
|
||||||
skipped_keys = [item for item in lora_clip if item not in loaded_keys]
|
skipped_keys = [item for item in lora_clip if item not in loaded_keys]
|
||||||
if len(skipped_keys) > 12:
|
if len(skipped_keys) > 12:
|
||||||
print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
|
print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
|
||||||
else:
|
else:
|
||||||
print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
|
print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys) with on_the_fly = {online_mode}')
|
||||||
|
clip = new_clip
|
||||||
|
|
||||||
return
|
return model, clip
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=5)
|
@functools.lru_cache(maxsize=5)
|
||||||
@@ -97,9 +99,14 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
network_on_disk.read_hash()
|
network_on_disk.read_hash()
|
||||||
loaded_networks.append(net)
|
loaded_networks.append(net)
|
||||||
|
|
||||||
|
online_mode = dynamic_args.get('online_lora', False)
|
||||||
|
|
||||||
|
if current_sd.forge_objects.unet.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||||
|
online_mode = False
|
||||||
|
|
||||||
compiled_lora_targets = []
|
compiled_lora_targets = []
|
||||||
for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
|
for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
|
||||||
compiled_lora_targets.append([a.filename, b, c])
|
compiled_lora_targets.append([a.filename, b, c, online_mode])
|
||||||
|
|
||||||
compiled_lora_targets_hash = str(compiled_lora_targets)
|
compiled_lora_targets_hash = str(compiled_lora_targets)
|
||||||
|
|
||||||
@@ -107,15 +114,14 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
return
|
return
|
||||||
|
|
||||||
current_sd.current_lora_hash = compiled_lora_targets_hash
|
current_sd.current_lora_hash = compiled_lora_targets_hash
|
||||||
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet.clone()
|
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
|
||||||
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip.clone()
|
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
|
||||||
|
|
||||||
current_sd.forge_objects.unet.lora_loader.clear_patches()
|
for filename, strength_model, strength_clip, online_mode in compiled_lora_targets:
|
||||||
current_sd.forge_objects.clip.patcher.lora_loader.clear_patches()
|
|
||||||
|
|
||||||
for filename, strength_model, strength_clip in compiled_lora_targets:
|
|
||||||
lora_sd = load_lora_state_dict(filename)
|
lora_sd = load_lora_state_dict(filename)
|
||||||
load_lora_for_models(current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, filename=filename)
|
current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
|
||||||
|
current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
|
||||||
|
filename=filename, online_mode=online_mode)
|
||||||
|
|
||||||
current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
|
current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -802,8 +802,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
memory_management.unload_all_models()
|
memory_management.unload_all_models()
|
||||||
|
|
||||||
if need_global_unload:
|
if need_global_unload:
|
||||||
p.sd_model.current_lora_hash = str([])
|
|
||||||
p.sd_model.forge_objects.unet.lora_loader.dirty = True
|
|
||||||
p.clear_prompt_cache()
|
p.clear_prompt_cache()
|
||||||
|
|
||||||
need_global_unload = False
|
need_global_unload = False
|
||||||
|
|||||||
Reference in New Issue
Block a user