revise lora patching

Author: layerdiffusion
Date: 2024-08-22 11:59:43 -07:00
Parent: 28ad046447
Commit: 2ab19f7f1c

2 changed files with 28 additions and 9 deletions

File 1 of 2:

@@ -68,8 +68,13 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation
 def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
     # Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
-    weight_original_dtype = weight.dtype
-    weight = weight.to(dtype=computation_dtype)
+    weight_dtype_backup = None
+
+    if computation_dtype == weight.dtype:
+        weight = weight.clone()
+    else:
+        weight_dtype_backup = weight.dtype
+        weight = weight.to(dtype=computation_dtype)
 
     for p in patches:
         strength = p[0]
@@ -253,7 +258,9 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
             if old_weight is not None:
                 weight = old_weight
 
-    weight = weight.to(dtype=weight_original_dtype)
+    if weight_dtype_backup is not None:
+        weight = weight.to(dtype=weight_dtype_backup)
+
     return weight
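
These two hunks replace an unconditional cast with a guarded one. The guard matters because torch.Tensor.to() returns the same tensor object when the dtype already matches, so the old code could end up patching the stored weight in place; the new code clones in that case and records a dtype to restore only when an actual cast happened. A minimal sketch of the pattern in isolation, with a hypothetical apply_patches callable standing in for the per-patch merge loop:

import torch

def patch_weight(weight, apply_patches, computation_dtype=torch.float32):
    # Guarded cast: Tensor.to() returns `weight` itself when the dtype
    # already matches, so clone() avoids mutating the stored weight.
    weight_dtype_backup = None
    if computation_dtype == weight.dtype:
        weight = weight.clone()
    else:
        weight_dtype_backup = weight.dtype
        weight = weight.to(dtype=computation_dtype)

    weight = apply_patches(weight)  # stands in for the merge loop

    # Cast back only if a cast actually happened.
    if weight_dtype_backup is not None:
        weight = weight.to(dtype=weight_dtype_backup)
    return weight
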
@@ -267,6 +274,7 @@ class LoraLoader:
         self.backup = {}
         self.online_backup = []
         self.dirty = False
+        self.online_mode = False
 
     def clear_patches(self):
         self.patches.clear()
@@ -296,6 +304,13 @@ class LoraLoader:
             self.patches[key] = current_patches
 
         self.dirty = True
+
+        self.online_mode = dynamic_args.get('online_lora', False)
+
+        if hasattr(self.model, 'storage_dtype'):
+            if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
+                self.online_mode = False
+
         return list(p)
 
     @torch.inference_mode()
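
Instead of re-reading dynamic_args every time patching runs (as the code below did before this commit), online mode is now decided once, when patches are added, and is forced off when the model stores its weights in an ordinary float dtype. The likely rationale: on-the-fly LoRA only pays off when the stored weights cannot accept a merged delta directly (e.g. quantized storage); for fp32/fp16/bf16 storage, precomputing merged weights is cheaper per sampling step. A sketch of that decision as a standalone function (hypothetical name, same logic):

import torch

def resolve_online_mode(model, requested_online):
    # Honor the 'online_lora' request only when the model's storage dtype
    # is not an ordinary float type; otherwise merge ahead of time.
    online = bool(requested_online)
    if hasattr(model, 'storage_dtype'):
        if model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            online = False
    return online
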
@@ -323,7 +338,11 @@ class LoraLoader:
         self.backup = {}
 
-        online_mode = dynamic_args.get('online_lora', False)
-
         if len(self.patches) > 0:
+            if self.online_mode:
+                print('Patching LoRA in on-the-fly.')
+            else:
+                print('Patching LoRA by precomputing model weights.')
+
             # Patch
@@ -337,7 +356,7 @@ class LoraLoader:
                 if key not in self.backup:
                     self.backup[key] = weight.to(device=offload_device)
 
-                if online_mode:
+                if self.online_mode:
                     if not hasattr(parent_layer, 'forge_online_loras'):
                         parent_layer.forge_online_loras = {}
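
When online mode is active, patches are attached to the owning layer via a forge_online_loras dict rather than merged into the weight. The commit does not show the consuming side, but conceptually a linear layer would add the low-rank term at forward time: y = xWᵀ + s · (xAᵀ)Bᵀ with s = strength · alpha / rank. A purely illustrative sketch; the dict layout ('down', 'up', 'alpha', 'strength') is assumed for illustration only:

import torch
import torch.nn.functional as F

def linear_forward_with_online_lora(layer, x):
    # Base output, then each low-rank correction applied at forward time
    # instead of being merged into layer.weight.
    y = F.linear(x, layer.weight, layer.bias)
    for lora in getattr(layer, 'forge_online_loras', {}).values():
        down, up = lora['down'], lora['up']  # [rank, in], [out, rank]
        scale = lora['strength'] * lora['alpha'] / down.shape[0]
        y = y + scale * F.linear(F.linear(x, down), up)
    return y
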

File 2 of 2:

@@ -376,7 +376,7 @@ def sampling_prepare(unet, x):
         additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
         additional_model_patchers += unet.controlnet_linked_list.get_models()
 
-    if dynamic_args.get('online_lora', False):
+    if unet.lora_loader.online_mode:
         lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
         additional_inference_memory += lora_memory
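
The memory accounting now keys off the loader's own online_mode flag rather than the global dynamic_args, so it always agrees with how the patches were actually installed. For reference, a helper of nested_compute_size's shape only needs to sum tensor bytes across a nested structure; a minimal sketch of that idea (the real Forge utility may differ):

import torch

def nested_compute_size(obj):
    # Walk dicts, lists and tuples, summing the byte size of every tensor,
    # so still-offloaded LoRA patches can be counted against the budget.
    if isinstance(obj, torch.Tensor):
        return obj.numel() * obj.element_size()
    if isinstance(obj, dict):
        return sum(nested_compute_size(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(nested_compute_size(v) for v in obj)
    return 0
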
@@ -384,11 +384,9 @@ def sampling_prepare(unet, x):
         models=[unet] + additional_model_patchers,
         memory_required=unet_inference_memory + additional_inference_memory)
 
-    if dynamic_args.get('online_lora', False):
+    if unet.lora_loader.online_mode:
         utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)
-        unet.lora_loader.patches = {}
 
     real_model = unet.model
     percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)
@@ -400,6 +398,8 @@ def sampling_prepare(unet, x):
 def sampling_cleanup(unet):
+    if unet.lora_loader.online_mode:
+        utils.nested_move_to_device(unet.lora_loader.patches, device=unet.offload_device)
+
     for cnet in unet.list_controlnets():
         cnet.cleanup()
     cleanup_cache()
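
Note how the last two hunks pair up: sampling_prepare no longer discards the patches after moving them to the GPU, and sampling_cleanup instead moves them back to the offload device, so the same LoRA patches can be reused across generations while VRAM is still freed in between. A sketch of what a move helper of this shape has to do, mutating containers in place since the call sites ignore the return value (the real utils.nested_move_to_device may differ):

import torch

def nested_move_to_device(obj, device=None):
    # Recursively move every tensor in a nested dict/list structure to
    # `device`, mutating the containers in place.
    if isinstance(obj, dict):
        keys = list(obj.keys())
    elif isinstance(obj, list):
        keys = range(len(obj))
    else:
        return obj
    for k in keys:
        v = obj[k]
        if isinstance(v, torch.Tensor):
            obj[k] = v.to(device=device)
        else:
            nested_move_to_device(v, device=device)
    return obj
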