mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 00:49:48 +00:00
only load lora one time
This commit is contained in:
@@ -11,7 +11,7 @@ import inspect
|
||||
|
||||
from tqdm import tqdm
|
||||
from backend import memory_management, utils, operations
|
||||
from backend.patcher.lora import merge_lora_to_model_weight
|
||||
from backend.patcher.lora import merge_lora_to_model_weight, LoraLoader
|
||||
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
@@ -54,14 +54,18 @@ class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.model_options = {"transformer_options": {}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
|
||||
if not hasattr(model, 'lora_loader'):
|
||||
model.lora_loader = LoraLoader(model)
|
||||
|
||||
self.lora_loader: LoraLoader = model.lora_loader
|
||||
|
||||
if current_device is None:
|
||||
self.current_device = self.offload_device
|
||||
else:
|
||||
@@ -75,10 +79,6 @@ class ModelPatcher:
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
return n
|
||||
@@ -193,28 +193,6 @@ class ModelPatcher:
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
p = set()
|
||||
model_sd = self.model.state_dict()
|
||||
for k in patches:
|
||||
offset = None
|
||||
function = None
|
||||
if isinstance(k, str):
|
||||
key = k
|
||||
else:
|
||||
offset = k[1]
|
||||
key = k[0]
|
||||
if len(k) > 2:
|
||||
function = k[2]
|
||||
|
||||
if key in model_sd:
|
||||
p.add(k)
|
||||
current_patches = self.patches.get(key, [])
|
||||
current_patches.append((strength_patch, patches[k], strength_model, offset, function))
|
||||
self.patches[key] = current_patches
|
||||
|
||||
return list(p)
|
||||
|
||||
def get_key_patches(self, filter_prefix=None):
|
||||
memory_management.unload_model_clones(self)
|
||||
model_sd = self.model_state_dict()
|
||||
@@ -239,8 +217,6 @@ class ModelPatcher:
|
||||
return sd
|
||||
|
||||
def forge_patch_model(self, target_device=None):
|
||||
execution_start_time = time.perf_counter()
|
||||
|
||||
for k, item in self.object_patches.items():
|
||||
old = utils.get_attr(self.model, k)
|
||||
|
||||
@@ -249,102 +225,21 @@ class ModelPatcher:
|
||||
|
||||
utils.set_attr_raw(self.model, k, item)
|
||||
|
||||
for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs') if len(self.patches) > 0 else self.patches):
|
||||
try:
|
||||
weight = utils.get_attr(self.model, key)
|
||||
assert isinstance(weight, torch.nn.Parameter)
|
||||
except:
|
||||
raise ValueError(f"Wrong LoRA Key: {key}")
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device)
|
||||
|
||||
bnb_layer = None
|
||||
|
||||
if operations.bnb_avaliable:
|
||||
if hasattr(weight, 'bnb_quantized'):
|
||||
assert weight.module is not None, 'BNB bad weight without parent layer!'
|
||||
bnb_layer = weight.module
|
||||
if weight.bnb_quantized:
|
||||
weight_original_device = weight.device
|
||||
|
||||
if target_device is not None:
|
||||
assert target_device.type == 'cuda', 'BNB Must use CUDA!'
|
||||
weight = weight.to(target_device)
|
||||
else:
|
||||
weight = weight.cuda()
|
||||
|
||||
from backend.operations_bnb import functional_dequantize_4bit
|
||||
weight = functional_dequantize_4bit(weight)
|
||||
|
||||
if target_device is None:
|
||||
weight = weight.to(device=weight_original_device)
|
||||
else:
|
||||
weight = weight.data
|
||||
|
||||
if target_device is not None:
|
||||
weight = weight.to(device=target_device)
|
||||
|
||||
gguf_cls, gguf_type, gguf_real_shape = None, None, None
|
||||
|
||||
if hasattr(weight, 'is_gguf'):
|
||||
from backend.operations_gguf import dequantize_tensor
|
||||
gguf_cls = weight.gguf_cls
|
||||
gguf_type = weight.gguf_type
|
||||
gguf_real_shape = weight.gguf_real_shape
|
||||
weight = dequantize_tensor(weight)
|
||||
|
||||
weight_original_dtype = weight.dtype
|
||||
weight = weight.to(dtype=torch.float32)
|
||||
weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
|
||||
|
||||
if bnb_layer is not None:
|
||||
bnb_layer.reload_weight(weight)
|
||||
continue
|
||||
|
||||
if gguf_cls is not None:
|
||||
from backend.operations_gguf import ParameterGGUF
|
||||
weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
|
||||
utils.set_attr_raw(self.model, key, ParameterGGUF.make(
|
||||
data=weight,
|
||||
gguf_type=gguf_type,
|
||||
gguf_cls=gguf_cls,
|
||||
gguf_real_shape=gguf_real_shape
|
||||
))
|
||||
continue
|
||||
|
||||
utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
|
||||
self.lora_loader.refresh(target_device=target_device, offload_device=self.offload_device)
|
||||
|
||||
if target_device is not None:
|
||||
self.model.to(target_device)
|
||||
self.current_device = target_device
|
||||
|
||||
moving_time = time.perf_counter() - execution_start_time
|
||||
|
||||
if moving_time > 0.1:
|
||||
print(f'LoRA patching has taken {moving_time:.2f} seconds')
|
||||
|
||||
return self.model
|
||||
|
||||
def forge_unpatch_model(self, target_device=None):
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
for k in keys:
|
||||
w = self.backup[k]
|
||||
|
||||
if not isinstance(w, torch.nn.Parameter):
|
||||
# In very few cases
|
||||
w = torch.nn.Parameter(w, requires_grad=False)
|
||||
|
||||
utils.set_attr_raw(self.model, k, w)
|
||||
|
||||
self.backup = {}
|
||||
|
||||
if target_device is not None:
|
||||
self.model.to(target_device)
|
||||
self.current_device = target_device
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
|
||||
for k in keys:
|
||||
utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
|
||||
@@ -25,6 +25,3 @@ class CLIP:
|
||||
n.cond_stage_model = self.cond_stage_model
|
||||
n.tokenizer = self.tokenizer
|
||||
return n
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import torch
|
||||
import time
|
||||
|
||||
import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
|
||||
import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
|
||||
|
||||
from backend import memory_management
|
||||
from tqdm import tqdm
|
||||
from backend import memory_management, utils, operations
|
||||
|
||||
|
||||
class ForgeLoraCollection:
|
||||
@@ -77,7 +79,7 @@ def merge_lora_to_model_weight(patches, weight, key):
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (calculate_weight(v[1:], v[0].clone(), key),)
|
||||
v = (merge_lora_to_model_weight(v[1:], v[0].clone(), key),)
|
||||
|
||||
patch_type = ''
|
||||
|
||||
@@ -238,3 +240,140 @@ def merge_lora_to_model_weight(patches, weight, key):
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
class LoraLoader:
    """Collect LoRA patches for one model and merge them into its weights on demand.

    Patches are accumulated with ``add_patches`` and only applied when
    ``refresh`` is called while the loader is marked dirty, so an unchanged
    patch set is not merged again.
    """

    def __init__(self, model):
        self.model = model
        # key -> list of (strength_patch, patch, strength_model, offset, function)
        self.patches = {}
        # key -> weight captured just before it was patched, used to restore
        self.backup = {}
        # True when ``patches`` changed since the last ``refresh``
        self.dirty = False

    def clear_patches(self):
        """Drop every registered patch and mark the loader as needing a refresh."""
        self.patches.clear()
        self.dirty = True

    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
        """Register patches whose target key exists in the model's state dict.

        Args:
            patches: Mapping whose keys are either a state-dict key (``str``)
                or a sequence ``(key, offset[, function])``.
            strength_patch: Strength stored alongside each patch.
            strength_model: Model strength stored alongside each patch.

        Returns:
            A list of the keys of ``patches`` (in their original form) that
            matched a state-dict entry and were registered.
        """
        loaded = set()
        model_sd = self.model.state_dict()

        for k in patches:
            offset = None
            function = None

            if isinstance(k, str):
                key = k
            else:
                key = k[0]
                offset = k[1]
                if len(k) > 2:
                    function = k[2]

            if key in model_sd:
                loaded.add(k)
                self.patches.setdefault(key, []).append(
                    (strength_patch, patches[k], strength_model, offset, function)
                )

        # Marked dirty even when nothing matched, as before.
        self.dirty = True
        return list(loaded)

    def refresh(self, target_device=None, offload_device=torch.device('cpu')):
        """Restore backed-up weights, then merge the current patches into the model.

        Does nothing unless the loader is dirty.

        Args:
            target_device: When not ``None``, restored and patched weights are
                moved to this device.
            offload_device: Device on which the pre-patch copy of each weight
                is kept. The default is a real ``torch.device``; the previous
                default ``torch.cpu`` is a module and cannot be passed to
                ``Tensor.to``.

        Raises:
            ValueError: If a patch key cannot be resolved on the model or does
                not resolve to a ``torch.nn.Parameter``.
        """
        if not self.dirty:
            return

        self.dirty = False
        execution_start_time = time.perf_counter()

        self._restore_backup(target_device)

        # Only build the progress bar when there is something to patch.
        if self.patches:
            for key, current_patches in tqdm(self.patches.items(), desc='Patching LoRAs'):
                self._merge_into_weight(key, current_patches, target_device, offload_device)

        moving_time = time.perf_counter() - execution_start_time

        if moving_time > 0.1:
            print(f'LoRA patching has taken {moving_time:.2f} seconds')

        return

    def _restore_backup(self, target_device):
        """Put every backed-up weight back on the model and empty the backup."""
        for k, w in self.backup.items():
            if target_device is not None:
                w = w.to(device=target_device)

            if not isinstance(w, torch.nn.Parameter):
                # In very few cases the stored object is no longer a Parameter
                w = torch.nn.Parameter(w, requires_grad=False)

            utils.set_attr_raw(self.model, k, w)

        self.backup = {}

    def _merge_into_weight(self, key, current_patches, target_device, offload_device):
        """Back up the weight at ``key`` and replace it with its LoRA-merged version."""
        try:
            weight = utils.get_attr(self.model, key)
        except Exception as e:
            raise ValueError(f"Wrong LoRA Key: {key}") from e

        if not isinstance(weight, torch.nn.Parameter):
            raise ValueError(f"Wrong LoRA Key: {key}")

        if key not in self.backup:
            self.backup[key] = weight.to(device=offload_device)

        bnb_layer = None

        if operations.bnb_avaliable:
            if hasattr(weight, 'bnb_quantized'):
                assert weight.module is not None, 'BNB bad weight without parent layer!'
                bnb_layer = weight.module
                if weight.bnb_quantized:
                    weight_original_device = weight.device

                    # Dequantization is done on CUDA; without a target device
                    # the result is moved back to where the weight came from.
                    if target_device is not None:
                        assert target_device.type == 'cuda', 'BNB Must use CUDA!'
                        weight = weight.to(target_device)
                    else:
                        weight = weight.cuda()

                    from backend.operations_bnb import functional_dequantize_4bit
                    weight = functional_dequantize_4bit(weight)

                    if target_device is None:
                        weight = weight.to(device=weight_original_device)
                else:
                    weight = weight.data

        if target_device is not None:
            weight = weight.to(device=target_device)

        gguf_cls, gguf_type, gguf_real_shape = None, None, None

        if hasattr(weight, 'is_gguf'):
            from backend.operations_gguf import dequantize_tensor
            # Remember the GGUF metadata so the merged weight can be re-quantized.
            gguf_cls = weight.gguf_cls
            gguf_type = weight.gguf_type
            gguf_real_shape = weight.gguf_real_shape
            weight = dequantize_tensor(weight)

        # Merge in float32, then cast back to the weight's own dtype.
        weight_original_dtype = weight.dtype
        weight = weight.to(dtype=torch.float32)
        weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)

        if bnb_layer is not None:
            # The parent layer takes the merged weight back itself.
            bnb_layer.reload_weight(weight)
            return

        if gguf_cls is not None:
            from backend.operations_gguf import ParameterGGUF
            weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
            utils.set_attr_raw(self.model, key, ParameterGGUF.make(
                data=weight,
                gguf_type=gguf_type,
                gguf_cls=gguf_cls,
                gguf_real_shape=gguf_real_shape
            ))
            return

        utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
|
||||
|
||||
@@ -25,11 +25,6 @@ class UnetPatcher(ModelPatcher):
|
||||
|
||||
def clone(self):
|
||||
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
|
||||
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.controlnet_linked_list = self.controlnet_linked_list
|
||||
@@ -196,5 +191,5 @@ class UnetPatcher(ModelPatcher):
|
||||
for patch_type, weight_list in v.items():
|
||||
patch_flat[model_key] = (patch_type, weight_list)
|
||||
|
||||
self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
||||
self.lora_loader.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
||||
return
|
||||
|
||||
@@ -32,28 +32,23 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
|
||||
if len(lora_unmatch) > 0:
|
||||
print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
|
||||
|
||||
new_model = model.clone() if model is not None else None
|
||||
new_clip = clip.clone() if clip is not None else None
|
||||
|
||||
if new_model is not None and len(lora_unet) > 0:
|
||||
loaded_keys = new_model.add_patches(lora_unet, strength_model)
|
||||
if model is not None and len(lora_unet) > 0:
|
||||
loaded_keys = model.lora_loader.add_patches(lora_unet, strength_model)
|
||||
skipped_keys = [item for item in lora_unet if item not in loaded_keys]
|
||||
if len(skipped_keys) > 12:
|
||||
print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
|
||||
else:
|
||||
print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
|
||||
model = new_model
|
||||
|
||||
if new_clip is not None and len(lora_clip) > 0:
|
||||
loaded_keys = new_clip.add_patches(lora_clip, strength_clip)
|
||||
if clip is not None and len(lora_clip) > 0:
|
||||
loaded_keys = clip.patcher.lora_loader.add_patches(lora_clip, strength_clip)
|
||||
skipped_keys = [item for item in lora_clip if item not in loaded_keys]
|
||||
if len(skipped_keys) > 12:
|
||||
print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
|
||||
else:
|
||||
print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
|
||||
clip = new_clip
|
||||
|
||||
return model, clip
|
||||
return
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=5)
|
||||
@@ -112,14 +107,15 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
return
|
||||
|
||||
current_sd.current_lora_hash = compiled_lora_targets_hash
|
||||
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
|
||||
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
|
||||
current_sd.forge_objects.unet = current_sd.forge_objects_original.unet.clone()
|
||||
current_sd.forge_objects.clip = current_sd.forge_objects_original.clip.clone()
|
||||
|
||||
current_sd.forge_objects.unet.lora_loader.clear_patches()
|
||||
current_sd.forge_objects.clip.patcher.lora_loader.clear_patches()
|
||||
|
||||
for filename, strength_model, strength_clip in compiled_lora_targets:
|
||||
lora_sd = load_lora_state_dict(filename)
|
||||
current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
|
||||
current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
|
||||
filename=filename)
|
||||
load_lora_for_models(current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, filename=filename)
|
||||
|
||||
current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user