mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-08 17:09:59 +00:00
revise lora patching
This commit is contained in:
@@ -68,8 +68,13 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation
|
||||
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
|
||||
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
|
||||
|
||||
weight_original_dtype = weight.dtype
|
||||
weight = weight.to(dtype=computation_dtype)
|
||||
weight_dtype_backup = None
|
||||
|
||||
if computation_dtype == weight.dtype:
|
||||
weight = weight.clone()
|
||||
else:
|
||||
weight_dtype_backup = weight.dtype
|
||||
weight = weight.to(dtype=computation_dtype)
|
||||
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
@@ -253,7 +258,9 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
|
||||
if old_weight is not None:
|
||||
weight = old_weight
|
||||
|
||||
weight = weight.to(dtype=weight_original_dtype)
|
||||
if weight_dtype_backup is not None:
|
||||
weight = weight.to(dtype=weight_dtype_backup)
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
@@ -267,6 +274,7 @@ class LoraLoader:
|
||||
self.backup = {}
|
||||
self.online_backup = []
|
||||
self.dirty = False
|
||||
self.online_mode = False
|
||||
|
||||
def clear_patches(self):
|
||||
self.patches.clear()
|
||||
@@ -296,6 +304,13 @@ class LoraLoader:
|
||||
self.patches[key] = current_patches
|
||||
|
||||
self.dirty = True
|
||||
|
||||
self.online_mode = dynamic_args.get('online_lora', False)
|
||||
|
||||
if hasattr(self.model, 'storage_dtype'):
|
||||
if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
self.online_mode = False
|
||||
|
||||
return list(p)
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -323,7 +338,11 @@ class LoraLoader:
|
||||
|
||||
self.backup = {}
|
||||
|
||||
online_mode = dynamic_args.get('online_lora', False)
|
||||
if len(self.patches) > 0:
|
||||
if self.online_mode:
|
||||
print('Patching LoRA in on-the-fly.')
|
||||
else:
|
||||
print('Patching LoRA by precomputing model weights.')
|
||||
|
||||
# Patch
|
||||
|
||||
@@ -337,7 +356,7 @@ class LoraLoader:
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=offload_device)
|
||||
|
||||
if online_mode:
|
||||
if self.online_mode:
|
||||
if not hasattr(parent_layer, 'forge_online_loras'):
|
||||
parent_layer.forge_online_loras = {}
|
||||
|
||||
|
||||
@@ -376,7 +376,7 @@ def sampling_prepare(unet, x):
|
||||
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
|
||||
additional_model_patchers += unet.controlnet_linked_list.get_models()
|
||||
|
||||
if dynamic_args.get('online_lora', False):
|
||||
if unet.lora_loader.online_mode:
|
||||
lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
|
||||
additional_inference_memory += lora_memory
|
||||
|
||||
@@ -384,11 +384,9 @@ def sampling_prepare(unet, x):
|
||||
models=[unet] + additional_model_patchers,
|
||||
memory_required=unet_inference_memory + additional_inference_memory)
|
||||
|
||||
if dynamic_args.get('online_lora', False):
|
||||
if unet.lora_loader.online_mode:
|
||||
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
|
||||
|
||||
unet.lora_loader.patches = {}
|
||||
|
||||
real_model = unet.model
|
||||
|
||||
percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)
|
||||
@@ -400,6 +398,8 @@ def sampling_prepare(unet, x):
|
||||
|
||||
|
||||
def sampling_cleanup(unet):
|
||||
if unet.lora_loader.online_mode:
|
||||
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.offload_device)
|
||||
for cnet in unet.list_controlnets():
|
||||
cnet.cleanup()
|
||||
cleanup_cache()
|
||||
|
||||
Reference in New Issue
Block a user