mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 02:31:16 +00:00
revise lora patching
This commit is contained in:
@@ -68,8 +68,13 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, computation
|
||||
def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=torch.float32):
|
||||
# Modified from https://github.com/comfyanonymous/ComfyUI/blob/39f114c44bb99d4a221e8da451d4f2a20119c674/comfy/model_patcher.py#L446
|
||||
|
||||
weight_original_dtype = weight.dtype
|
||||
weight = weight.to(dtype=computation_dtype)
|
||||
weight_dtype_backup = None
|
||||
|
||||
if computation_dtype == weight.dtype:
|
||||
weight = weight.clone()
|
||||
else:
|
||||
weight_dtype_backup = weight.dtype
|
||||
weight = weight.to(dtype=computation_dtype)
|
||||
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
@@ -253,7 +258,9 @@ def merge_lora_to_weight(patches, weight, key="online_lora", computation_dtype=t
|
||||
if old_weight is not None:
|
||||
weight = old_weight
|
||||
|
||||
weight = weight.to(dtype=weight_original_dtype)
|
||||
if weight_dtype_backup is not None:
|
||||
weight = weight.to(dtype=weight_dtype_backup)
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
@@ -267,6 +274,7 @@ class LoraLoader:
|
||||
self.backup = {}
|
||||
self.online_backup = []
|
||||
self.dirty = False
|
||||
self.online_mode = False
|
||||
|
||||
def clear_patches(self):
|
||||
self.patches.clear()
|
||||
@@ -296,6 +304,13 @@ class LoraLoader:
|
||||
self.patches[key] = current_patches
|
||||
|
||||
self.dirty = True
|
||||
|
||||
self.online_mode = dynamic_args.get('online_lora', False)
|
||||
|
||||
if hasattr(self.model, 'storage_dtype'):
|
||||
if self.model.storage_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
self.online_mode = False
|
||||
|
||||
return list(p)
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -323,7 +338,11 @@ class LoraLoader:
|
||||
|
||||
self.backup = {}
|
||||
|
||||
online_mode = dynamic_args.get('online_lora', False)
|
||||
if len(self.patches) > 0:
|
||||
if self.online_mode:
|
||||
print('Patching LoRA in on-the-fly.')
|
||||
else:
|
||||
print('Patching LoRA by precomputing model weights.')
|
||||
|
||||
# Patch
|
||||
|
||||
@@ -337,7 +356,7 @@ class LoraLoader:
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=offload_device)
|
||||
|
||||
if online_mode:
|
||||
if self.online_mode:
|
||||
if not hasattr(parent_layer, 'forge_online_loras'):
|
||||
parent_layer.forge_online_loras = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user