mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
revise lora patching
This commit is contained in:
@@ -68,8 +68,13 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation
|
|||||||
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
|
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
|
||||||
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
|
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
|
||||||
|
|
||||||
weight_original_dtype = weight.dtype
|
weight_dtype_backup = None
|
||||||
weight = weight.to(dtype=computation_dtype)
|
|
||||||
|
if computation_dtype == weight.dtype:
|
||||||
|
weight = weight.clone()
|
||||||
|
else:
|
||||||
|
weight_dtype_backup = weight.dtype
|
||||||
|
weight = weight.to(dtype=computation_dtype)
|
||||||
|
|
||||||
for p in patches:
|
for p in patches:
|
||||||
strength = p[0]
|
strength = p[0]
|
||||||
@@ -253,7 +258,9 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
|
|||||||
if old_weight is not None:
|
if old_weight is not None:
|
||||||
weight = old_weight
|
weight = old_weight
|
||||||
|
|
||||||
weight = weight.to(dtype=weight_original_dtype)
|
if weight_dtype_backup is not None:
|
||||||
|
weight = weight.to(dtype=weight_dtype_backup)
|
||||||
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
|
||||||
@@ -267,6 +274,7 @@ class LoraLoader:
|
|||||||
self.backup = {}
|
self.backup = {}
|
||||||
self.online_backup = []
|
self.online_backup = []
|
||||||
self.dirty = False
|
self.dirty = False
|
||||||
|
self.online_mode = False
|
||||||
|
|
||||||
def clear_patches(self):
|
def clear_patches(self):
|
||||||
self.patches.clear()
|
self.patches.clear()
|
||||||
@@ -296,6 +304,13 @@ class LoraLoader:
|
|||||||
self.patches[key] = current_patches
|
self.patches[key] = current_patches
|
||||||
|
|
||||||
self.dirty = True
|
self.dirty = True
|
||||||
|
|
||||||
|
self.online_mode = dynamic_args.get('online_lora', False)
|
||||||
|
|
||||||
|
if hasattr(self.model, 'storage_dtype'):
|
||||||
|
if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||||
|
self.online_mode = False
|
||||||
|
|
||||||
return list(p)
|
return list(p)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@@ -323,7 +338,11 @@ class LoraLoader:
|
|||||||
|
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
|
|
||||||
online_mode = dynamic_args.get('online_lora', False)
|
if len(self.patches) > 0:
|
||||||
|
if self.online_mode:
|
||||||
|
print('Patching LoRA in on-the-fly.')
|
||||||
|
else:
|
||||||
|
print('Patching LoRA by precomputing model weights.')
|
||||||
|
|
||||||
# Patch
|
# Patch
|
||||||
|
|
||||||
@@ -337,7 +356,7 @@ class LoraLoader:
|
|||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = weight.to(device=offload_device)
|
self.backup[key] = weight.to(device=offload_device)
|
||||||
|
|
||||||
if online_mode:
|
if self.online_mode:
|
||||||
if not hasattr(parent_layer, 'forge_online_loras'):
|
if not hasattr(parent_layer, 'forge_online_loras'):
|
||||||
parent_layer.forge_online_loras = {}
|
parent_layer.forge_online_loras = {}
|
||||||
|
|
||||||
|
|||||||
@@ -376,7 +376,7 @@ def sampling_prepare(unet, x):
|
|||||||
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
|
additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
|
||||||
additional_model_patchers += unet.controlnet_linked_list.get_models()
|
additional_model_patchers += unet.controlnet_linked_list.get_models()
|
||||||
|
|
||||||
if dynamic_args.get('online_lora', False):
|
if unet.lora_loader.online_mode:
|
||||||
lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
|
lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
|
||||||
additional_inference_memory += lora_memory
|
additional_inference_memory += lora_memory
|
||||||
|
|
||||||
@@ -384,11 +384,9 @@ def sampling_prepare(unet, x):
|
|||||||
models=[unet] + additional_model_patchers,
|
models=[unet] + additional_model_patchers,
|
||||||
memory_required=unet_inference_memory + additional_inference_memory)
|
memory_required=unet_inference_memory + additional_inference_memory)
|
||||||
|
|
||||||
if dynamic_args.get('online_lora', False):
|
if unet.lora_loader.online_mode:
|
||||||
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
|
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
|
||||||
|
|
||||||
unet.lora_loader.patches = {}
|
|
||||||
|
|
||||||
real_model = unet.model
|
real_model = unet.model
|
||||||
|
|
||||||
percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)
|
percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)
|
||||||
@@ -400,6 +398,8 @@ def sampling_prepare(unet, x):
|
|||||||
|
|
||||||
|
|
||||||
def sampling_cleanup(unet):
|
def sampling_cleanup(unet):
|
||||||
|
if unet.lora_loader.online_mode:
|
||||||
|
utils.nested_move_to_device(unet.lora_loader.patches, device=unet.offload_device)
|
||||||
for cnet in unet.list_controlnets():
|
for cnet in unet.list_controlnets():
|
||||||
cnet.cleanup()
|
cnet.cleanup()
|
||||||
cleanup_cache()
|
cleanup_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user