diff --git a/backend/patcher/base.py b/backend/patcher/base.py
index 2b12ac07..5d9b6cb7 100644
--- a/backend/patcher/base.py
+++ b/backend/patcher/base.py
@@ -7,29 +7,7 @@
 import copy
 import inspect
 from backend import memory_management, utils
-
-extra_weight_calculators = {}
-
-
-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
-    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32)
-    lora_diff *= alpha
-    weight_calc = weight + lora_diff.type(weight.dtype)
-    weight_norm = (
-        weight_calc.transpose(0, 1)
-        .reshape(weight_calc.shape[1], -1)
-        .norm(dim=1, keepdim=True)
-        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
-        .transpose(0, 1)
-    )
-
-    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
-    if strength != 1.0:
-        weight_calc -= weight
-        weight += strength * weight_calc
-    else:
-        weight[:] = weight_calc
-    return weight
+from backend.patcher.lora import merge_lora_to_model_weight
 
 
 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
@@ -258,225 +236,43 @@ class ModelPatcher:
                     sd.pop(k)
         return sd
 
-    def patch_model(self, device_to=None, patch_weights=True):
+    def patch_model(self, device_to=None):
         for k in self.object_patches:
             old = utils.get_attr(self.model, k)
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
             utils.set_attr_raw(self.model, k, self.object_patches[k])
 
-        if patch_weights:
-            model_sd = self.model_state_dict()
-            for key in self.patches:
-                if key not in model_sd:
-                    print("could not patch. key doesn't exist in model:", key)
-                    continue
+        model_state_dict = self.model_state_dict()
 
-                weight = model_sd[key]
+        for key, current_patches in self.patches.items():
+            assert key in model_state_dict, f"Wrong LoRA Key: {key}"
 
-                inplace_update = self.weight_inplace_update
+            weight = model_state_dict[key]
 
-                if key not in self.backup:
-                    self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+            inplace_update = self.weight_inplace_update
 
-                if device_to is not None:
-                    temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-                else:
-                    temp_weight = weight.to(torch.float32, copy=True)
-                out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-                if inplace_update:
-                    utils.copy_to_param(self.model, key, out_weight)
-                else:
-                    utils.set_attr(self.model, key, out_weight)
-                del temp_weight
+            if key not in self.backup:
+                self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
 
             if device_to is not None:
-                self.model.to(device_to)
-                self.current_device = device_to
+                temp_weight = memory_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+            else:
+                temp_weight = weight.to(torch.float32, copy=True)
+
+            out_weight = merge_lora_to_model_weight(current_patches, temp_weight, key).to(weight.dtype)
+
+            if inplace_update:
+                utils.copy_to_param(self.model, key, out_weight)
+            else:
+                utils.set_attr(self.model, key, out_weight)
+
+        if device_to is not None:
+            self.model.to(device_to)
+            self.current_device = device_to
 
         return self.model
 
-    def calculate_weight(self, patches, weight, key):
-        for p in patches:
-            strength = p[0]
-            v = p[1]
-            strength_model = p[2]
-            offset = p[3]
-            function = p[4]
-            if function is None:
-                function = lambda a: a
-
-            old_weight = None
-            if offset is not None:
-                old_weight = weight
-                weight = weight.narrow(offset[0], offset[1], offset[2])
-
-            if strength_model != 1.0:
-                weight *= strength_model
-
-            if isinstance(v, list):
-                v = (self.calculate_weight(v[1:], v[0].clone(), key),)
-
-            patch_type = ''
-
-            if len(v) == 1:
-                patch_type = "diff"
-            elif len(v) == 2:
-                patch_type = v[0]
-                v = v[1]
-
-            if patch_type == "diff":
-                w1 = v[0]
-                if strength != 0.0:
-                    if w1.shape != weight.shape:
-                        if w1.ndim == weight.ndim == 4:
-                            new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
-                            print(f'Merged with {key} channel changed to {new_shape}')
-                            new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
-                            new_weight = torch.zeros(size=new_shape).to(weight)
-                            new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
-                            new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
-                            new_weight = new_weight.contiguous().clone()
-                            weight = new_weight
-                        else:
-                            print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
-                    else:
-                        weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
-            elif patch_type == "lora":
-                mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32)
-                mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32)
-                dora_scale = v[4]
-                if v[2] is not None:
-                    alpha = v[2] / mat2.shape[0]
-                else:
-                    alpha = 1.0
-
-                if v[3] is not None:
-                    mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32)
-                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
-                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                try:
-                    lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    print("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "lokr":
-                w1 = v[0]
-                w2 = v[1]
-                w1_a = v[3]
-                w1_b = v[4]
-                w2_a = v[5]
-                w2_b = v[6]
-                t2 = v[7]
-                dora_scale = v[8]
-                dim = None
-
-                if w1 is None:
-                    dim = w1_b.shape[0]
-                    w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32),
-                                  memory_management.cast_to_device(w1_b, weight.device, torch.float32))
-                else:
-                    w1 = memory_management.cast_to_device(w1, weight.device, torch.float32)
-
-                if w2 is None:
-                    dim = w2_b.shape[0]
-                    if t2 is None:
-                        w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32),
-                                      memory_management.cast_to_device(w2_b, weight.device, torch.float32))
-                    else:
-                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                          memory_management.cast_to_device(t2, weight.device, torch.float32),
-                                          memory_management.cast_to_device(w2_b, weight.device, torch.float32),
-                                          memory_management.cast_to_device(w2_a, weight.device, torch.float32))
-                else:
-                    w2 = memory_management.cast_to_device(w2, weight.device, torch.float32)
-
-                if len(w2.shape) == 4:
-                    w1 = w1.unsqueeze(2).unsqueeze(2)
-                if v[2] is not None and dim is not None:
-                    alpha = v[2] / dim
-                else:
-                    alpha = 1.0
-
-                try:
-                    lora_diff = torch.kron(w1, w2).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    print("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "loha":
-                w1a = v[0]
-                w1b = v[1]
-                if v[2] is not None:
-                    alpha = v[2] / w1b.shape[0]
-                else:
-                    alpha = 1.0
-
-                w2a = v[3]
-                w2b = v[4]
-                dora_scale = v[7]
-                if v[5] is not None:
-                    t1 = v[5]
-                    t2 = v[6]
-                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      memory_management.cast_to_device(t1, weight.device, torch.float32),
-                                      memory_management.cast_to_device(w1b, weight.device, torch.float32),
-                                      memory_management.cast_to_device(w1a, weight.device, torch.float32))
-
-                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      memory_management.cast_to_device(t2, weight.device, torch.float32),
-                                      memory_management.cast_to_device(w2b, weight.device, torch.float32),
-                                      memory_management.cast_to_device(w2a, weight.device, torch.float32))
-                else:
-                    m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32),
-                                  memory_management.cast_to_device(w1b, weight.device, torch.float32))
-                    m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32),
-                                  memory_management.cast_to_device(w2b, weight.device, torch.float32))
-
-                try:
-                    lora_diff = (m1 * m2).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    print("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "glora":
-                if v[4] is not None:
-                    alpha = v[4] / v[0].shape[0]
-                else:
-                    alpha = 1.0
-
-                dora_scale = v[5]
-
-                a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
-                a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
-                b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
-                b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
-
-                try:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    print("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type in extra_weight_calculators:
-                weight = extra_weight_calculators[patch_type](weight, strength, v)
-            else:
-                print("patch type not recognized {} {}".format(patch_type, key))
-
-            if old_weight is not None:
-                weight = old_weight
-
-        return weight
-
     def unpatch_model(self, device_to=None):
         keys = list(self.backup.keys())
diff --git a/backend/patcher/lora.py b/backend/patcher/lora.py
index ec87b000..ea49acbe 100644
--- a/backend/patcher/lora.py
+++ b/backend/patcher/lora.py
@@ -1,12 +1,18 @@
+import torch
+
 import packages_3rdparty.webui_lora_collection.lora as lora_utils_webui
 import packages_3rdparty.comfyui_lora_collection.lora as lora_utils_comfyui
 
+from backend import memory_management
+
 
 class ForgeLoraCollection:
     # TODO
     pass
 
 
+extra_weight_calculators = {}
+
 lora_utils_forge = ForgeLoraCollection()
 
 lora_collection_priority = [lora_utils_forge, lora_utils_webui, lora_utils_comfyui]
@@ -29,3 +35,206 @@ def model_lora_keys_clip(model, key_map={}):
 
 def model_lora_keys_unet(model, key_map={}):
     return get_function('model_lora_keys_unet')(model, key_map)
+
+
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
+    dora_scale = memory_management.cast_to_device(dora_scale, weight.device, torch.float32)
+    lora_diff *= alpha
+    weight_calc = weight + lora_diff.type(weight.dtype)
+    weight_norm = (
+        weight_calc.transpose(0, 1)
+        .reshape(weight_calc.shape[1], -1)
+        .norm(dim=1, keepdim=True)
+        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
+        .transpose(0, 1)
+    )
+
+    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
+    if strength != 1.0:
+        weight_calc -= weight
+        weight += strength * weight_calc
+    else:
+        weight[:] = weight_calc
+    return weight
+
+
+def merge_lora_to_model_weight(patches, weight, key):
+    for p in patches:
+        strength = p[0]
+        v = p[1]
+        strength_model = p[2]
+        offset = p[3]
+        function = p[4]
+        if function is None:
+            function = lambda a: a
+
+        old_weight = None
+        if offset is not None:
+            old_weight = weight
+            weight = weight.narrow(offset[0], offset[1], offset[2])
+
+        if strength_model != 1.0:
+            weight *= strength_model
+
+        if isinstance(v, list):
+            v = (merge_lora_to_model_weight(v[1:], v[0].clone(), key),)
+
+        patch_type = ''
+
+        if len(v) == 1:
+            patch_type = "diff"
+        elif len(v) == 2:
+            patch_type = v[0]
+            v = v[1]
+
+        if patch_type == "diff":
+            w1 = v[0]
+            if strength != 0.0:
+                if w1.shape != weight.shape:
+                    if w1.ndim == weight.ndim == 4:
+                        new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)]
+                        print(f'Merged with {key} channel changed to {new_shape}')
+                        new_diff = strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
+                        new_weight = torch.zeros(size=new_shape).to(weight)
+                        new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight
+                        new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff
+                        new_weight = new_weight.contiguous().clone()
+                        weight = new_weight
+                    else:
+                        print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                else:
+                    weight += strength * memory_management.cast_to_device(w1, weight.device, weight.dtype)
+        elif patch_type == "lora":
+            mat1 = memory_management.cast_to_device(v[0], weight.device, torch.float32)
+            mat2 = memory_management.cast_to_device(v[1], weight.device, torch.float32)
+            dora_scale = v[4]
+            if v[2] is not None:
+                alpha = v[2] / mat2.shape[0]
+            else:
+                alpha = 1.0
+
+            if v[3] is not None:
+                mat3 = memory_management.cast_to_device(v[3], weight.device, torch.float32)
+                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+            try:
+                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                print("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type == "lokr":
+            w1 = v[0]
+            w2 = v[1]
+            w1_a = v[3]
+            w1_b = v[4]
+            w2_a = v[5]
+            w2_b = v[6]
+            t2 = v[7]
+            dora_scale = v[8]
+            dim = None
+
+            if w1 is None:
+                dim = w1_b.shape[0]
+                w1 = torch.mm(memory_management.cast_to_device(w1_a, weight.device, torch.float32),
+                              memory_management.cast_to_device(w1_b, weight.device, torch.float32))
+            else:
+                w1 = memory_management.cast_to_device(w1, weight.device, torch.float32)
+
+            if w2 is None:
+                dim = w2_b.shape[0]
+                if t2 is None:
+                    w2 = torch.mm(memory_management.cast_to_device(w2_a, weight.device, torch.float32),
+                                  memory_management.cast_to_device(w2_b, weight.device, torch.float32))
+                else:
+                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                      memory_management.cast_to_device(t2, weight.device, torch.float32),
+                                      memory_management.cast_to_device(w2_b, weight.device, torch.float32),
+                                      memory_management.cast_to_device(w2_a, weight.device, torch.float32))
+            else:
+                w2 = memory_management.cast_to_device(w2, weight.device, torch.float32)
+
+            if len(w2.shape) == 4:
+                w1 = w1.unsqueeze(2).unsqueeze(2)
+            if v[2] is not None and dim is not None:
+                alpha = v[2] / dim
+            else:
+                alpha = 1.0
+
+            try:
+                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                print("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type == "loha":
+            w1a = v[0]
+            w1b = v[1]
+            if v[2] is not None:
+                alpha = v[2] / w1b.shape[0]
+            else:
+                alpha = 1.0
+
+            w2a = v[3]
+            w2b = v[4]
+            dora_scale = v[7]
+            if v[5] is not None:
+                t1 = v[5]
+                t2 = v[6]
+                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                  memory_management.cast_to_device(t1, weight.device, torch.float32),
+                                  memory_management.cast_to_device(w1b, weight.device, torch.float32),
+                                  memory_management.cast_to_device(w1a, weight.device, torch.float32))
+
+                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                  memory_management.cast_to_device(t2, weight.device, torch.float32),
+                                  memory_management.cast_to_device(w2b, weight.device, torch.float32),
+                                  memory_management.cast_to_device(w2a, weight.device, torch.float32))
+            else:
+                m1 = torch.mm(memory_management.cast_to_device(w1a, weight.device, torch.float32),
+                              memory_management.cast_to_device(w1b, weight.device, torch.float32))
+                m2 = torch.mm(memory_management.cast_to_device(w2a, weight.device, torch.float32),
+                              memory_management.cast_to_device(w2b, weight.device, torch.float32))
+
+            try:
+                lora_diff = (m1 * m2).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                print("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type == "glora":
+            if v[4] is not None:
+                alpha = v[4] / v[0].shape[0]
+            else:
+                alpha = 1.0
+
+            dora_scale = v[5]
+
+            a1 = memory_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
+            a2 = memory_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
+            b1 = memory_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
+            b2 = memory_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
+
+            try:
+                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                print("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type in extra_weight_calculators:
+            weight = extra_weight_calculators[patch_type](weight, strength, v)
+        else:
+            print("patch type not recognized {} {}".format(patch_type, key))
+
+        if old_weight is not None:
+            weight = old_weight
+
+    return weight
diff --git a/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py b/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py
index 20ed2d57..1a3ef8cd 100644
--- a/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py
+++ b/extensions-builtin/sd_forge_fooocus_inpaint/scripts/forge_fooocus_inpaint.py
@@ -1,7 +1,7 @@
 import os
 import torch
 import copy
-import backend.patcher.base
+import backend.patcher.lora
 
 from modules_forge.shared import add_supported_control_model
 from modules_forge.supported_controlnet import ControlModelPatcher
@@ -127,5 +127,5 @@ class FooocusInpaintPatcher(ControlModelPatcher):
         return
 
 
-backend.patcher.base.extra_weight_calculators['fooocus'] = calculate_weight_fooocus
+backend.patcher.lora.extra_weight_calculators['fooocus'] = calculate_weight_fooocus
 
 add_supported_control_model(FooocusInpaintPatcher)
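Note for extension authors: the registry for custom patch types now lives in backend.patcher.lora rather than backend.patcher.base, as the forge_fooocus_inpaint change above shows, and merge_lora_to_model_weight invokes entries as extra_weight_calculators[patch_type](weight, strength, v). The sketch below is a minimal illustration of the new registration path; the 'scale_only' patch type, its payload layout, and the function name are hypothetical examples and not part of this diff, only the registry name and the (weight, strength, v) call signature come from the patched code.

    import backend.patcher.lora as lora_patcher


    def calculate_weight_scale_only(weight, strength, v):
        # Hypothetical calculator: v is assumed to hold a single tensor of
        # per-element multipliers; blend it into the weight by `strength`.
        scale = v[0].to(device=weight.device, dtype=weight.dtype)
        weight += strength * (scale * weight - weight)
        return weight


    # Register under a custom patch type; merge_lora_to_model_weight dispatches
    # here whenever a patch tuple carries the 'scale_only' type.
    lora_patcher.extra_weight_calculators['scale_only'] = calculate_weight_scale_only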