Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-01-26 19:09:45 +00:00)
1. Add an option that lets users run the UNet in fp8/gguf while keeping LoRA in fp16.
2. FP16 LoRAs no longer need patching at all; other LoRAs are only patched again when their weights change.
3. FP8 UNet + fp16 LoRA is now available in Forge (and, for now, essentially only in Forge). This also resolves some "LoRA too subtle" problems.
4. Significantly speed up all gguf models (in Async mode) by using an independent thread (CUDA stream) to compute and dequantize at the same time, even when the low-bit weights are already on the GPU; a rough sketch of this stream-overlap idea follows this list.
5. Treat the "online LoRA" as a module similar to ControlLoRA, so it is moved to the GPU together with the model when sampling, achieving both a significant speedup and clean low-VRAM management.
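A minimal, hypothetical PyTorch sketch of the stream-overlap idea from item 4. The helper name dequantize_fn is an illustrative stand-in, not Forge's actual API; the real implementation lives in the backend modules below and differs.

import torch
import torch.nn.functional as F

# Side stream dedicated to dequantizing low-bit (e.g. gguf) weights.
dequant_stream = torch.cuda.Stream()

def prefetch_weight(quantized_weight, dequantize_fn):
    # Launch dequantization of the next layer's weight on the side stream so it
    # overlaps with the current layer's compute on the default stream.
    with torch.cuda.stream(dequant_stream):
        return dequantize_fn(quantized_weight)

def consume_weight(x, prefetched_weight):
    # The default stream must wait for the side stream before reading the weight.
    torch.cuda.current_stream().wait_stream(dequant_stream)
    # (A production version would also call record_stream() for allocator safety.)
    return F.linear(x, prefetched_weight)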
245 lines · 9.4 KiB · Python
# Model Patching System, Copyright Forge 2024

# API templates are partially extracted from ComfyUI; the actual implementations of those APIs
# are from Forge, implemented from scratch (after forge-v1.0.1), and may differ
# from ComfyUI to a certain degree.

import copy
import inspect

from backend import memory_management, utils
from backend.patcher.lora import LoraLoader


def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
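    """Register a replace-style patch for one transformer block in model_options.

    The nested dicts are copied before modification, and the block is keyed by
    (block_name, number) plus an optional transformer_index.
    """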
    to = model_options["transformer_options"].copy()

    if "patches_replace" not in to:
        to["patches_replace"] = {}
    else:
        to["patches_replace"] = to["patches_replace"].copy()

    if name not in to["patches_replace"]:
        to["patches_replace"][name] = {}
    else:
        to["patches_replace"][name] = to["patches_replace"][name].copy()

    if transformer_index is not None:
        block = (block_name, number, transformer_index)
    else:
        block = (block_name, number)
    to["patches_replace"][name][block] = patch
    model_options["transformer_options"] = to
    return model_options


def set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=False):
    model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options


def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False):
    model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function]
    if disable_cfg1_optimization:
        model_options["disable_cfg1_optimization"] = True
    return model_options


class ModelPatcher:
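    """Wraps a model together with its load/offload devices and patch bookkeeping.

    Object patches (attribute swaps), transformer patches and sampler hooks are
    stored on the patcher; LoRA weights are handled by the per-model LoraLoader
    shared between clones.
    """
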
    def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
        self.size = size
        self.model = model
        self.object_patches = {}
        self.object_patches_backup = {}
        self.model_options = {"transformer_options": {}}
        self.model_size()
        self.load_device = load_device
        self.offload_device = offload_device

        if not hasattr(model, 'lora_loader'):
            model.lora_loader = LoraLoader(model)

        self.lora_loader: LoraLoader = model.lora_loader

        if current_device is None:
            self.current_device = self.offload_device
        else:
            self.current_device = current_device

    def model_size(self):
        if self.size > 0:
            return self.size
        self.size = memory_management.module_size(self.model)
        return self.size

    def clone(self):
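        """Return a new ModelPatcher sharing the same underlying model.

        Object patches are shallow-copied and model_options is deep-copied, so
        the clone can be patched independently.
        """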
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
        n.object_patches = self.object_patches.copy()
        n.model_options = copy.deepcopy(self.model_options)
        return n

    def is_clone(self, other):
        if hasattr(other, 'model') and self.model is other.model:
            return True
        return False

    def memory_required(self, input_shape):
        return self.model.memory_required(input_shape=input_shape)

    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
        if len(inspect.signature(sampler_cfg_function).parameters) == 3:
            self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"])  # Old way
        else:
            self.model_options["sampler_cfg_function"] = sampler_cfg_function
        if disable_cfg1_optimization:
            self.model_options["disable_cfg1_optimization"] = True

    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
        self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)

    def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False):
        self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization)

    def set_model_unet_function_wrapper(self, unet_wrapper_function):
        self.model_options["model_function_wrapper"] = unet_wrapper_function

    def set_model_vae_encode_wrapper(self, wrapper_function):
        self.model_options["model_vae_encode_wrapper"] = wrapper_function

    def set_model_vae_decode_wrapper(self, wrapper_function):
        self.model_options["model_vae_decode_wrapper"] = wrapper_function

    def set_model_vae_regulation(self, vae_regulation):
        self.model_options["model_vae_regulation"] = vae_regulation

    def set_model_denoise_mask_function(self, denoise_mask_function):
        self.model_options["denoise_mask_function"] = denoise_mask_function

    def set_model_patch(self, patch, name):
        to = self.model_options["transformer_options"]
        if "patches" not in to:
            to["patches"] = {}
        to["patches"][name] = to["patches"].get(name, []) + [patch]

    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
        self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index)

    def set_model_attn1_patch(self, patch):
        self.set_model_patch(patch, "attn1_patch")

    def set_model_attn2_patch(self, patch):
        self.set_model_patch(patch, "attn2_patch")

    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

    def set_model_attn1_output_patch(self, patch):
        self.set_model_patch(patch, "attn1_output_patch")

    def set_model_attn2_output_patch(self, patch):
        self.set_model_patch(patch, "attn2_output_patch")

    def set_model_input_block_patch(self, patch):
        self.set_model_patch(patch, "input_block_patch")

    def set_model_input_block_patch_after_skip(self, patch):
        self.set_model_patch(patch, "input_block_patch_after_skip")

    def set_model_output_block_patch(self, patch):
        self.set_model_patch(patch, "output_block_patch")

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

    def get_model_object(self, name):
        if name in self.object_patches:
            return self.object_patches[name]
        else:
            if name in self.object_patches_backup:
                return self.object_patches_backup[name]
            else:
                return utils.get_attr(self.model, name)

    def model_patches_to(self, device):
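        """Move patch modules (and the model function wrapper, if any) to device."""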
        to = self.model_options["transformer_options"]
        if "patches" in to:
            patches = to["patches"]
            for name in patches:
                patch_list = patches[name]
                for i in range(len(patch_list)):
                    if hasattr(patch_list[i], "to"):
                        patch_list[i] = patch_list[i].to(device)
        if "patches_replace" in to:
            patches = to["patches_replace"]
            for name in patches:
                patch_list = patches[name]
                for k in patch_list:
                    if hasattr(patch_list[k], "to"):
                        patch_list[k] = patch_list[k].to(device)
        if "model_function_wrapper" in self.model_options:
            wrap_func = self.model_options["model_function_wrapper"]
            if hasattr(wrap_func, "to"):
                self.model_options["model_function_wrapper"] = wrap_func.to(device)

    def model_dtype(self):
        if hasattr(self.model, "get_dtype"):
            return self.model.get_dtype()

    def get_key_patches(self, filter_prefix=None):
        memory_management.unload_model_clones(self)
        model_sd = self.model_state_dict()
        p = {}
        for k in model_sd:
            if filter_prefix is not None:
                if not k.startswith(filter_prefix):
                    continue
            if k in self.patches:
                p[k] = [model_sd[k]] + self.patches[k]
            else:
                p[k] = (model_sd[k],)
        return p

    def model_state_dict(self, filter_prefix=None):
        sd = self.model.state_dict()
        keys = list(sd.keys())
        if filter_prefix is not None:
            for k in keys:
                if not k.startswith(filter_prefix):
                    sd.pop(k)
        return sd

    def forge_patch_model(self, target_device=None):
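        """Apply object patches, refresh LoRA weights, and optionally move the model.

        Original attributes are backed up so forge_unpatch_model can restore them.
        """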
        for k, item in self.object_patches.items():
            old = utils.get_attr(self.model, k)

            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = old

            utils.set_attr_raw(self.model, k, item)

        self.lora_loader.refresh(target_device=target_device, offload_device=self.offload_device)

        if target_device is not None:
            self.model.to(target_device)

        return self.model

    def forge_unpatch_model(self, target_device=None):
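        """Restore backed-up attributes and optionally move the model to target_device."""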
        if target_device is not None:
            self.model.to(target_device)
            self.current_device = target_device

        keys = list(self.object_patches_backup.keys())

        for k in keys:
            utils.set_attr_raw(self.model, k, self.object_patches_backup[k])

        self.object_patches_backup = {}
        return