diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index 3b41198f..f60e60b7 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -11,7 +11,7 @@ import inspect
 from tqdm import tqdm
 
 from backend import memory_management, utils, operations
-from backend.patcher.lora import merge_lora_to_model_weight
+from backend.patcher.lora import merge_lora_to_model_weight, LoraLoader
 
 
 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
@@ -54,14 +54,18 @@ class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, **kwargs):
         self.size = size
         self.model = model
-        self.patches = {}
-        self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
         self.model_options = {"transformer_options": {}}
         self.model_size()
         self.load_device = load_device
         self.offload_device = offload_device
+
+        if not hasattr(model, 'lora_loader'):
+            model.lora_loader = LoraLoader(model)
+
+        self.lora_loader: LoraLoader = model.lora_loader
+
         if current_device is None:
             self.current_device = self.offload_device
         else:
@@ -75,10 +79,6 @@ class ModelPatcher:
 
     def clone(self):
         n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
-        n.patches = {}
-        for k in self.patches:
-            n.patches[k] = self.patches[k][:]
-
         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         return n
@@ -193,28 +193,6 @@ class ModelPatcher:
         if hasattr(self.model, "get_dtype"):
             return self.model.get_dtype()
 
-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        p = set()
-        model_sd = self.model.state_dict()
-        for k in patches:
-            offset = None
-            function = None
-            if isinstance(k, str):
-                key = k
-            else:
-                offset = k[1]
-                key = k[0]
-                if len(k) > 2:
-                    function = k[2]
-
-            if key in model_sd:
-                p.add(k)
-                current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
-                self.patches[key] = current_patches
-
-        return list(p)
-
     def get_key_patches(self, filter_prefix=None):
         memory_management.unload_model_clones(self)
         model_sd = self.model_state_dict()
@@ -239,8 +217,6 @@ class ModelPatcher:
         return sd
 
     def forge_patch_model(self, target_device=None):
-        execution_start_time = time.perf_counter()
-
         for k, item in self.object_patches.items():
             old = utils.get_attr(self.model, k)
 
@@ -249,102 +225,21 @@ class ModelPatcher:
             utils.set_attr_raw(self.model, k, item)
 
-        for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs') if len(self.patches) > 0 else self.patches):
-            try:
-                weight = utils.get_attr(self.model, key)
-                assert isinstance(weight, torch.nn.Parameter)
-            except:
-                raise ValueError(f"Wrong LoRA Key: {key}")
-
-            if key not in self.backup:
-                self.backup[key] = weight.to(device=self.offload_device)
-
-            bnb_layer = None
-
-            if operations.bnb_avaliable:
-                if hasattr(weight, 'bnb_quantized'):
-                    assert weight.module is not None, 'BNB bad weight without parent layer!'
-                    bnb_layer = weight.module
-                    if weight.bnb_quantized:
-                        weight_original_device = weight.device
-
-                        if target_device is not None:
-                            assert target_device.type == 'cuda', 'BNB Must use CUDA!'
-                            weight = weight.to(target_device)
-                        else:
-                            weight = weight.cuda()
-
-                        from backend.operations_bnb import functional_dequantize_4bit
-                        weight = functional_dequantize_4bit(weight)
-
-                        if target_device is None:
-                            weight = weight.to(device=weight_original_device)
-                    else:
-                        weight = weight.data
-
-            if target_device is not None:
-                weight = weight.to(device=target_device)
-
-            gguf_cls, gguf_type, gguf_real_shape = None, None, None
-
-            if hasattr(weight, 'is_gguf'):
-                from backend.operations_gguf import dequantize_tensor
-                gguf_cls = weight.gguf_cls
-                gguf_type = weight.gguf_type
-                gguf_real_shape = weight.gguf_real_shape
-                weight = dequantize_tensor(weight)
-
-            weight_original_dtype = weight.dtype
-            weight = weight.to(dtype=torch.float32)
-            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
-
-            if bnb_layer is not None:
-                bnb_layer.reload_weight(weight)
-                continue
-
-            if gguf_cls is not None:
-                from backend.operations_gguf import ParameterGGUF
-                weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
-                utils.set_attr_raw(self.model, key, ParameterGGUF.make(
-                    data=weight,
-                    gguf_type=gguf_type,
-                    gguf_cls=gguf_cls,
-                    gguf_real_shape=gguf_real_shape
-                ))
-                continue
-
-            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
+        self.lora_loader.refresh(target_device=target_device, offload_device=self.offload_device)
 
         if target_device is not None:
             self.model.to(target_device)
             self.current_device = target_device
 
-        moving_time = time.perf_counter() - execution_start_time
-
-        if moving_time > 0.1:
-            print(f'LoRA patching has taken {moving_time:.2f} seconds')
-
         return self.model
 
     def forge_unpatch_model(self, target_device=None):
-        keys = list(self.backup.keys())
-
-        for k in keys:
-            w = self.backup[k]
-
-            if not isinstance(w, torch.nn.Parameter):
-                # In very few cases
-                w = torch.nn.Parameter(w, requires_grad=False)
-
-            utils.set_attr_raw(self.model, k, w)
-
-        self.backup = {}
-
         if target_device is not None:
             self.model.to(target_device)
             self.current_device = target_device
 
         keys = list(self.object_patches_backup.keys())
+
         for k in keys:
             utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
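Note: after this change a ModelPatcher no longer owns `patches`/`backup` dicts; the LoRA state lives on a `LoraLoader` attached to the underlying model, so every patcher cloned from the same model shares one loader, and `forge_patch_model` simply delegates to `lora_loader.refresh()`. A minimal sketch of that relationship, assuming `model` is an already-built backend module and `lora_patches` is a dict of LoRA weight patches keyed by parameter name (both placeholders, devices illustrative):

```python
# Sketch only: `model` and `lora_patches` are placeholders, devices are illustrative.
import torch

from backend.patcher.base import ModelPatcher

def apply_lora_sketch(model, lora_patches):
    load_device, offload_device = torch.device('cuda'), torch.device('cpu')

    patcher = ModelPatcher(model, load_device, offload_device)  # attaches model.lora_loader if missing
    clone = patcher.clone()

    # Clones no longer copy a per-patcher `patches` dict; they share the model's loader.
    assert patcher.lora_loader is clone.lora_loader is model.lora_loader

    patcher.lora_loader.add_patches(lora_patches, strength_patch=0.8)  # marks the loader dirty
    return patcher.forge_patch_model(target_device=load_device)        # delegates to lora_loader.refresh()
```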
diff --git a/backend/patcher/clip.py b/backend/patcher/clip.py
index 3979e96c..bb74ba13 100644
--- a/backend/patcher/clip.py
+++ b/backend/patcher/clip.py
@@ -25,6 +25,3 @@ class CLIP:
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         return n
-
-    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
-        return self.patcher.add_patches(patches, strength_patch, strength_model)
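The `CLIP.add_patches` wrapper is removed outright rather than redirected, so text-encoder LoRAs now go through the wrapped patcher's loader, which is what networks.py switches to below. A hedged before/after sketch (`clip`, `clip_patches` and `strength_clip` are placeholders):

```python
# Old call path (removed by this diff):
#   loaded_keys = clip.add_patches(clip_patches, strength_clip)
# New call path: reach the loader through the inner patcher.
loaded_keys = clip.patcher.lora_loader.add_patches(clip_patches, strength_patch=strength_clip)
```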
diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index ea49acbe..782a7b04 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -1,9 +1,11 @@
 import torch
+import time
 
 import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
 import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
 
-from backend import memory_management
+from tqdm import tqdm
+from backend import memory_management, utils, operations
 
 
 class ForgeLoraCollection:
@@ -77,7 +79,7 @@ def merge_lora_to_model_weight(patches, weight, key):
             weight *= strength_model
 
         if isinstance(v, list):
-            v = (calculate_weight(v[1:], v[0].clone(), key),)
+            v = (merge_lora_to_model_weight(v[1:], v[0].clone(), key),)
 
         patch_type = ''
@@ -238,3 +240,140 @@ def merge_lora_to_model_weight(patches, weight, key):
             weight = old_weight
 
     return weight
+
+
+class LoraLoader:
+    def __init__(self, model):
+        self.model = model
+        self.patches = {}
+        self.backup = {}
+        self.dirty = False
+
+    def clear_patches(self):
+        self.patches.clear()
+        self.dirty = True
+        return
+
+    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
+        p = set()
+        model_sd = self.model.state_dict()
+
+        for k in patches:
+            offset = None
+            function = None
+
+            if isinstance(k, str):
+                key = k
+            else:
+                offset = k[1]
+                key = k[0]
+                if len(k) > 2:
+                    function = k[2]
+
+            if key in model_sd:
+                p.add(k)
+                current_patches = self.patches.get(key, [])
+                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
+                self.patches[key] = current_patches
+
+        self.dirty = True
+        return list(p)
+
+    def refresh(self, target_device=None, offload_device=torch.device('cpu')):
+        if not self.dirty:
+            return
+
+        self.dirty = False
+
+        execution_start_time = time.perf_counter()
+
+        # Restore
+
+        for k, w in self.backup.items():
+            if target_device is not None:
+                w = w.to(device=target_device)
+
+            if not isinstance(w, torch.nn.Parameter):
+                # In very few cases
+                w = torch.nn.Parameter(w, requires_grad=False)
+
+            utils.set_attr_raw(self.model, k, w)
+
+        self.backup = {}
+
+        # Patch
+
+        for key, current_patches in (tqdm(self.patches.items(), desc='Patching LoRAs') if len(self.patches) > 0 else self.patches):
+            try:
+                weight = utils.get_attr(self.model, key)
+                assert isinstance(weight, torch.nn.Parameter)
+            except:
+                raise ValueError(f"Wrong LoRA Key: {key}")
+
+            if key not in self.backup:
+                self.backup[key] = weight.to(device=offload_device)
+
+            bnb_layer = None
+
+            if operations.bnb_avaliable:
+                if hasattr(weight, 'bnb_quantized'):
+                    assert weight.module is not None, 'BNB bad weight without parent layer!'
+                    bnb_layer = weight.module
+                    if weight.bnb_quantized:
+                        weight_original_device = weight.device
+
+                        if target_device is not None:
+                            assert target_device.type == 'cuda', 'BNB Must use CUDA!'
+                            weight = weight.to(target_device)
+                        else:
+                            weight = weight.cuda()
+
+                        from backend.operations_bnb import functional_dequantize_4bit
+                        weight = functional_dequantize_4bit(weight)
+
+                        if target_device is None:
+                            weight = weight.to(device=weight_original_device)
+                    else:
+                        weight = weight.data
+
+            if target_device is not None:
+                weight = weight.to(device=target_device)
+
+            gguf_cls, gguf_type, gguf_real_shape = None, None, None
+
+            if hasattr(weight, 'is_gguf'):
+                from backend.operations_gguf import dequantize_tensor
+                gguf_cls = weight.gguf_cls
+                gguf_type = weight.gguf_type
+                gguf_real_shape = weight.gguf_real_shape
+                weight = dequantize_tensor(weight)
+
+            weight_original_dtype = weight.dtype
+            weight = weight.to(dtype=torch.float32)
+            weight = merge_lora_to_model_weight(current_patches, weight, key).to(dtype=weight_original_dtype)
+
+            if bnb_layer is not None:
+                bnb_layer.reload_weight(weight)
+                continue
+
+            if gguf_cls is not None:
+                from backend.operations_gguf import ParameterGGUF
+                weight = gguf_cls.quantize_pytorch(weight, gguf_real_shape)
+                utils.set_attr_raw(self.model, key, ParameterGGUF.make(
+                    data=weight,
+                    gguf_type=gguf_type,
+                    gguf_cls=gguf_cls,
+                    gguf_real_shape=gguf_real_shape
+                ))
+                continue
+
+            utils.set_attr_raw(self.model, key, torch.nn.Parameter(weight, requires_grad=False))
+
+        # Time
+
+        moving_time = time.perf_counter() - execution_start_time
+
+        if moving_time > 0.1:
+            print(f'LoRA patching has taken {moving_time:.2f} seconds')
+
+        return
diff --git a/backend/patcher/unet.py b/backend/patcher/unet.py
index eeaea671..060680c9 100644
--- a/backend/patcher/unet.py
+++ b/backend/patcher/unet.py
@@ -25,11 +25,6 @@ class UnetPatcher(ModelPatcher):
 
     def clone(self):
        n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
-
-        n.patches = {}
-        for k in self.patches:
-            n.patches[k] = self.patches[k][:]
-
         n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.controlnet_linked_list = self.controlnet_linked_list
@@ -196,5 +191,5 @@ class UnetPatcher(ModelPatcher):
             for patch_type, weight_list in v.items():
                 patch_flat[model_key] = (patch_type, weight_list)
 
-        self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
+        self.lora_loader.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
         return
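The `dirty` flag is what keeps repeated `forge_patch_model` calls cheap: `refresh()` returns immediately unless `add_patches()` or `clear_patches()` has run since the last refresh, and when it does run it first restores the backed-up originals and then re-merges the registered patches. A rough sketch of the lifecycle, assuming `model` is a module whose parameter names match the patch keys and `first_lora`/`second_lora` are patch dicts (all placeholders):

```python
# Sketch only: model, first_lora and second_lora are placeholders.
import torch

from backend.patcher.lora import LoraLoader

def swap_loras_sketch(model, first_lora, second_lora):
    loader = LoraLoader(model)

    loader.add_patches(first_lora, strength_patch=1.0)   # dirty = True
    loader.refresh(offload_device=torch.device('cpu'))   # backs up originals, merges, dirty = False
    loader.refresh()                                      # no-op: nothing changed since the last refresh

    # Swapping LoRAs: drop the old patch set, register the new one, refresh once.
    loader.clear_patches()
    loader.add_patches(second_lora, strength_patch=0.5)
    loader.refresh()                                      # restores the backups, then merges second_lora
```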
diff --git a/extensions-builtin/sd_forge_lora/networks.py b/extensions-builtin/sd_forge_lora/networks.py
index c22ab01f..55096fd8 100644
--- a/extensions-builtin/sd_forge_lora/networks.py
+++ b/extensions-builtin/sd_forge_lora/networks.py
@@ -32,28 +32,23 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
     if len(lora_unmatch) > 0:
         print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')
 
-    new_model = model.clone() if model is not None else None
-    new_clip = clip.clone() if clip is not None else None
-
-    if new_model is not None and len(lora_unet) > 0:
-        loaded_keys = new_model.add_patches(lora_unet, strength_model)
+    if model is not None and len(lora_unet) > 0:
+        loaded_keys = model.lora_loader.add_patches(lora_unet, strength_model)
         skipped_keys = [item for item in lora_unet if item not in loaded_keys]
         if len(skipped_keys) > 12:
             print(f'[LORA] Mismatch {filename} for {model_flag}-UNet with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
         else:
             print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
-        model = new_model
 
-    if new_clip is not None and len(lora_clip) > 0:
-        loaded_keys = new_clip.add_patches(lora_clip, strength_clip)
+    if clip is not None and len(lora_clip) > 0:
+        loaded_keys = clip.patcher.lora_loader.add_patches(lora_clip, strength_clip)
         skipped_keys = [item for item in lora_clip if item not in loaded_keys]
         if len(skipped_keys) > 12:
             print(f'[LORA] Mismatch {filename} for {model_flag}-CLIP with {len(skipped_keys)} keys mismatched in {len(loaded_keys)} keys')
         else:
             print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
-        clip = new_clip
 
-    return model, clip
+    return
 
 
 @functools.lru_cache(maxsize=5)
@@ -112,14 +107,15 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
         return
 
     current_sd.current_lora_hash = compiled_lora_targets_hash
-    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
-    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
+    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet.clone()
+    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip.clone()
+
+    current_sd.forge_objects.unet.lora_loader.clear_patches()
+    current_sd.forge_objects.clip.patcher.lora_loader.clear_patches()
 
     for filename, strength_model, strength_clip in compiled_lora_targets:
         lora_sd = load_lora_state_dict(filename)
-        current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
-            current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
-            filename=filename)
+        load_lora_for_models(current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip, filename=filename)
 
     current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
     return
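Because `load_lora_for_models` now registers patches in place and returns nothing, the caller owns the clone/clear/apply sequence, and the expensive weight merge is deferred until `forge_patch_model` triggers `lora_loader.refresh()`. A condensed sketch of the new `load_networks` flow under those assumptions (`current_sd`, `compiled_lora_targets` and `load_lora_state_dict` come from the surrounding module):

```python
# Sketch of the caller-side flow after this diff; names mirror load_networks above.
unet = current_sd.forge_objects_original.unet.clone()
clip = current_sd.forge_objects_original.clip.clone()

# Start from a clean slate; the loaders persist on the underlying models across clones.
unet.lora_loader.clear_patches()
clip.patcher.lora_loader.clear_patches()

for filename, strength_model, strength_clip in compiled_lora_targets:
    lora_sd = load_lora_state_dict(filename)
    # Patches are registered in place; nothing is returned any more.
    load_lora_for_models(unet, clip, lora_sd, strength_model, strength_clip, filename=filename)

current_sd.forge_objects.unet = unet
current_sd.forge_objects.clip = clip
# No weights have been touched yet; the merge happens when forge_patch_model()
# calls lora_loader.refresh() on the next load of the UNet / text encoder.
```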