From 8bc539a23a64b53cafceeec6a6b5b4d9d98a6d2d Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Thu, 25 Jan 2024 14:18:36 -0800
Subject: [PATCH] i

---
 extensions-builtin/Lora/lyco_helpers.py  |  68 -------
 extensions-builtin/Lora/network.py       | 190 ------------------
 extensions-builtin/Lora/network_full.py  |  27 ---
 extensions-builtin/Lora/network_glora.py |  33 ---
 extensions-builtin/Lora/network_hada.py  |  55 -----
 extensions-builtin/Lora/network_ia3.py   |  30 ---
 extensions-builtin/Lora/network_lokr.py  |  64 ------
 extensions-builtin/Lora/network_lora.py  |  86 --------
 extensions-builtin/Lora/network_norm.py  |  28 ---
 extensions-builtin/Lora/network_oft.py   |  82 --------
 extensions-builtin/Lora/networks.py      | 100 +--------
 .../Lora/scripts/lora_script.py          |   1 -
 12 files changed, 6 insertions(+), 758 deletions(-)
 delete mode 100644 extensions-builtin/Lora/lyco_helpers.py
 delete mode 100644 extensions-builtin/Lora/network.py
 delete mode 100644 extensions-builtin/Lora/network_full.py
 delete mode 100644 extensions-builtin/Lora/network_glora.py
 delete mode 100644 extensions-builtin/Lora/network_hada.py
 delete mode 100644 extensions-builtin/Lora/network_ia3.py
 delete mode 100644 extensions-builtin/Lora/network_lokr.py
 delete mode 100644 extensions-builtin/Lora/network_lora.py
 delete mode 100644 extensions-builtin/Lora/network_norm.py
 delete mode 100644 extensions-builtin/Lora/network_oft.py

diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
deleted file mode 100644
index 1679a0ce..00000000
--- a/extensions-builtin/Lora/lyco_helpers.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import torch
-
-
-def make_weight_cp(t, wa, wb):
-    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
-    return torch.einsum('i j k l, i r -> r j k l', temp, wa)
-
-
-def rebuild_conventional(up, down, shape, dyn_dim=None):
-    up = up.reshape(up.size(0), -1)
-    down = down.reshape(down.size(0), -1)
-    if dyn_dim is not None:
-        up = up[:, :dyn_dim]
-        down = down[:dyn_dim, :]
-    return (up @ down).reshape(shape)
-
-
-def rebuild_cp_decomposition(up, down, mid):
-    up = up.reshape(up.size(0), -1)
-    down = down.reshape(down.size(0), -1)
-    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
-
-
-# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
-def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
-    '''
-    return a tuple of two value of input dimension decomposed by the number closest to factor
-    second value is higher or equal than first value.
-
-    In LoRA with Kroneckor Product, first value is a value for weight scale.
-    secon value is a value for weight.
-
-    Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
-
-    examples)
-    factor
-        -1               2                4               8               16               ...
-        127 -> 1, 127   127 -> 1, 127    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
-        128 -> 8, 16    128 -> 2, 64     128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
-        250 -> 10, 25   250 -> 2, 125    250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
-        360 -> 8, 45    360 -> 2, 180    360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
-        512 -> 16, 32   512 -> 2, 256    512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
-        1024 -> 32, 32  1024 -> 2, 512   1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
-    '''
-
-    if factor > 0 and (dimension % factor) == 0:
-        m = factor
-        n = dimension // factor
-        if m > n:
-            n, m = m, n
-        return m, n
-    if factor < 0:
-        factor = dimension
-    m, n = 1, dimension
-    length = m + n
-    while m<n:
-        new_m = m + 1
-        while dimension%new_m != 0:
-            new_m += 1
-        new_n = dimension//new_m
-        if new_m + new_n > length or new_m>factor:
-            break
-        else:
-            m, n = new_m, new_n
-    if m > n:
-        n, m = m, n
-    return m, n
-
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
deleted file mode 100644
index b8fd9194..00000000
--- a/extensions-builtin/Lora/network.py
+++ /dev/null
@@ -1,190 +0,0 @@
-from __future__ import annotations
-import os
-from collections import namedtuple
-import enum
-
-import torch.nn as nn
-import torch.nn.functional as F
-
-from modules import sd_models, cache, errors, hashes, shared
-
-NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
-
-metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
-
-
-class SdVersion(enum.Enum):
-    Unknown = 1
-    SD1 = 2
-    SD2 = 3
-    SDXL = 4
-
-
-class NetworkOnDisk:
-    def __init__(self, name, filename):
-        self.name = name
-        self.filename = filename
-        self.metadata = {}
-        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
-
-        def read_metadata():
-            metadata = sd_models.read_metadata_from_safetensors(filename)
-            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text
-
-            return metadata
-
-        if self.is_safetensors:
-            try:
-                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
-            except Exception as e:
-                errors.display(e, f"reading lora {filename}")
-
-        if self.metadata:
-            m = {}
-            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
-                m[k] = v
-
-            self.metadata = m
-
-        self.alias = self.metadata.get('ss_output_name', self.name)
-
-        self.hash = None
-        self.shorthash = None
-        self.set_hash(
-            self.metadata.get('sshs_model_hash') or
-            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
-            ''
-        )
-
-        self.sd_version = self.detect_version()
-
-    def detect_version(self):
-        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
-            return SdVersion.SDXL
-        elif str(self.metadata.get('ss_v2', "")) == "True":
-            return SdVersion.SD2
-        elif len(self.metadata):
-            return SdVersion.SD1
-
-        return SdVersion.Unknown
-
-    def set_hash(self, v):
-        self.hash = v
-        self.shorthash = self.hash[0:12]
-
-        if self.shorthash:
-            import networks
-            networks.available_network_hash_lookup[self.shorthash] = self
-
-    def read_hash(self):
-        if not self.hash:
-            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
-
-    def get_alias(self):
-        import networks
-        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
-            return self.name
-        else:
-            return self.alias
-
-
-class Network:  # LoraModule
-    def __init__(self, name, network_on_disk: NetworkOnDisk):
self.name = name - self.network_on_disk = network_on_disk - self.te_multiplier = 1.0 - self.unet_multiplier = 1.0 - self.dyn_dim = None - self.modules = {} - self.bundle_embeddings = {} - self.mtime = None - - self.mentioned_name = None - """the text that was used to add the network to prompt - can be either name or an alias""" - - -class ModuleType: - def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: - return None - - -class NetworkModule: - def __init__(self, net: Network, weights: NetworkWeights): - self.network = net - self.network_key = weights.network_key - self.sd_key = weights.sd_key - self.sd_module = weights.sd_module - - if hasattr(self.sd_module, 'weight'): - self.shape = self.sd_module.weight.shape - - self.ops = None - self.extra_kwargs = {} - if isinstance(self.sd_module, nn.Conv2d): - self.ops = F.conv2d - self.extra_kwargs = { - 'stride': self.sd_module.stride, - 'padding': self.sd_module.padding - } - elif isinstance(self.sd_module, nn.Linear): - self.ops = F.linear - elif isinstance(self.sd_module, nn.LayerNorm): - self.ops = F.layer_norm - self.extra_kwargs = { - 'normalized_shape': self.sd_module.normalized_shape, - 'eps': self.sd_module.eps - } - elif isinstance(self.sd_module, nn.GroupNorm): - self.ops = F.group_norm - self.extra_kwargs = { - 'num_groups': self.sd_module.num_groups, - 'eps': self.sd_module.eps - } - - self.dim = None - self.bias = weights.w.get("bias") - self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None - self.scale = weights.w["scale"].item() if "scale" in weights.w else None - - def multiplier(self): - if 'transformer' in self.sd_key[:20]: - return self.network.te_multiplier - else: - return self.network.unet_multiplier - - def calc_scale(self): - if self.scale is not None: - return self.scale - if self.dim is not None and self.alpha is not None: - return self.alpha / self.dim - - return 1.0 - - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=updown.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - return updown * self.calc_scale() * self.multiplier(), ex_bias - - def calc_updown(self, target): - raise NotImplementedError() - - def forward(self, x, y): - """A general forward implementation for all modules""" - if self.ops is None: - raise NotImplementedError() - else: - updown, ex_bias = self.calc_updown(self.sd_module.weight) - return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs) - diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py deleted file mode 100644 index f221c95f..00000000 --- a/extensions-builtin/Lora/network_full.py +++ /dev/null @@ -1,27 +0,0 @@ -import network - - -class ModuleTypeFull(network.ModuleType): - def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["diff"]): - return NetworkModuleFull(net, weights) - - return None - - -class NetworkModuleFull(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) - - self.weight = weights.w.get("diff") - self.ex_bias = weights.w.get("diff_b") - - 
def calc_updown(self, orig_weight): - output_shape = self.weight.shape - updown = self.weight.to(orig_weight.device) - if self.ex_bias is not None: - ex_bias = self.ex_bias.to(orig_weight.device) - else: - ex_bias = None - - return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py deleted file mode 100644 index efe5c681..00000000 --- a/extensions-builtin/Lora/network_glora.py +++ /dev/null @@ -1,33 +0,0 @@ - -import network - -class ModuleTypeGLora(network.ModuleType): - def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]): - return NetworkModuleGLora(net, weights) - - return None - -# adapted from https://github.com/KohakuBlueleaf/LyCORIS -class NetworkModuleGLora(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) - - if hasattr(self.sd_module, 'weight'): - self.shape = self.sd_module.weight.shape - - self.w1a = weights.w["a1.weight"] - self.w1b = weights.w["b1.weight"] - self.w2a = weights.w["a2.weight"] - self.w2b = weights.w["b2.weight"] - - def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device) - w1b = self.w1b.to(orig_weight.device) - w2a = self.w2a.to(orig_weight.device) - w2b = self.w2b.to(orig_weight.device) - - output_shape = [w1a.size(0), w1b.size(1)] - updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) - - return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py deleted file mode 100644 index d95a0fd1..00000000 --- a/extensions-builtin/Lora/network_hada.py +++ /dev/null @@ -1,55 +0,0 @@ -import lyco_helpers -import network - - -class ModuleTypeHada(network.ModuleType): - def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]): - return NetworkModuleHada(net, weights) - - return None - - -class NetworkModuleHada(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) - - if hasattr(self.sd_module, 'weight'): - self.shape = self.sd_module.weight.shape - - self.w1a = weights.w["hada_w1_a"] - self.w1b = weights.w["hada_w1_b"] - self.dim = self.w1b.shape[0] - self.w2a = weights.w["hada_w2_a"] - self.w2b = weights.w["hada_w2_b"] - - self.t1 = weights.w.get("hada_t1") - self.t2 = weights.w.get("hada_t2") - - def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device) - w1b = self.w1b.to(orig_weight.device) - w2a = self.w2a.to(orig_weight.device) - w2b = self.w2b.to(orig_weight.device) - - output_shape = [w1a.size(0), w1b.size(1)] - - if self.t1 is not None: - output_shape = [w1a.size(1), w1b.size(1)] - t1 = self.t1.to(orig_weight.device) - updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) - output_shape += t1.shape[2:] - else: - if len(w1b.shape) == 4: - output_shape += w1b.shape[2:] - updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) - - if self.t2 is not None: - t2 = self.t2.to(orig_weight.device) - updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) - else: - updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) - - updown = updown1 * updown2 - - return self.finalize_updown(updown, 
orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py deleted file mode 100644 index 96faeaf3..00000000 --- a/extensions-builtin/Lora/network_ia3.py +++ /dev/null @@ -1,30 +0,0 @@ -import network - - -class ModuleTypeIa3(network.ModuleType): - def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["weight"]): - return NetworkModuleIa3(net, weights) - - return None - - -class NetworkModuleIa3(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) - - self.w = weights.w["weight"] - self.on_input = weights.w["on_input"].item() - - def calc_updown(self, orig_weight): - w = self.w.to(orig_weight.device) - - output_shape = [w.size(0), orig_weight.size(1)] - if self.on_input: - output_shape.reverse() - else: - w = w.reshape(-1, 1) - - updown = orig_weight * w - - return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py deleted file mode 100644 index fcdaeafd..00000000 --- a/extensions-builtin/Lora/network_lokr.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch - -import lyco_helpers -import network - - -class ModuleTypeLokr(network.ModuleType): - def create_module(self, net: network.Network, weights: network.NetworkWeights): - has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w) - has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w) - if has_1 and has_2: - return NetworkModuleLokr(net, weights) - - return None - - -def make_kron(orig_shape, w1, w2): - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - w2 = w2.contiguous() - return torch.kron(w1, w2).reshape(orig_shape) - - -class NetworkModuleLokr(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) - - self.w1 = weights.w.get("lokr_w1") - self.w1a = weights.w.get("lokr_w1_a") - self.w1b = weights.w.get("lokr_w1_b") - self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim - self.w2 = weights.w.get("lokr_w2") - self.w2a = weights.w.get("lokr_w2_a") - self.w2b = weights.w.get("lokr_w2_b") - self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim - self.t2 = weights.w.get("lokr_t2") - - def calc_updown(self, orig_weight): - if self.w1 is not None: - w1 = self.w1.to(orig_weight.device) - else: - w1a = self.w1a.to(orig_weight.device) - w1b = self.w1b.to(orig_weight.device) - w1 = w1a @ w1b - - if self.w2 is not None: - w2 = self.w2.to(orig_weight.device) - elif self.t2 is None: - w2a = self.w2a.to(orig_weight.device) - w2b = self.w2b.to(orig_weight.device) - w2 = w2a @ w2b - else: - t2 = self.t2.to(orig_weight.device) - w2a = self.w2a.to(orig_weight.device) - w2b = self.w2b.to(orig_weight.device) - w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) - - output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] - if len(orig_weight.shape) == 4: - output_shape = orig_weight.shape - - updown = make_kron(output_shape, w1, w2) - - return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py deleted file mode 100644 index 4cc40295..00000000 --- a/extensions-builtin/Lora/network_lora.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch - -import lyco_helpers -import network -from 
modules import devices - - -class ModuleTypeLora(network.ModuleType): - def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]): - return NetworkModuleLora(net, weights) - - return None - - -class NetworkModuleLora(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) - - self.up_model = self.create_module(weights.w, "lora_up.weight") - self.down_model = self.create_module(weights.w, "lora_down.weight") - self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True) - - self.dim = weights.w["lora_down.weight"].shape[0] - - def create_module(self, weights, key, none_ok=False): - weight = weights.get(key) - - if weight is None and none_ok: - return None - - is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] - is_conv = type(self.sd_module) in [torch.nn.Conv2d] - - if is_linear: - weight = weight.reshape(weight.shape[0], -1) - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif is_conv and key == "lora_down.weight" or key == "dyn_up": - if len(weight.shape) == 2: - weight = weight.reshape(weight.shape[0], -1, 1, 1) - - if weight.shape[2] != 1 or weight.shape[3] != 1: - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) - else: - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif is_conv and key == "lora_mid.weight": - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) - elif is_conv and key == "lora_up.weight" or key == "dyn_down": - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - else: - raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') - - with torch.no_grad(): - if weight.shape != module.weight.shape: - weight = weight.reshape(module.weight.shape) - module.weight.copy_(weight) - - module.to(device=devices.cpu, dtype=devices.dtype) - module.weight.requires_grad_(False) - - return module - - def calc_updown(self, orig_weight): - up = self.up_model.weight.to(orig_weight.device) - down = self.down_model.weight.to(orig_weight.device) - - output_shape = [up.size(0), down.size(1)] - if self.mid_model is not None: - # cp-decomposition - mid = self.mid_model.weight.to(orig_weight.device) - updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) - output_shape += mid.shape[2:] - else: - if len(down.shape) == 4: - output_shape += down.shape[2:] - updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim) - - return self.finalize_updown(updown, orig_weight, output_shape) - - def forward(self, x, y): - self.up_model.to(device=devices.device) - self.down_model.to(device=devices.device) - - return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale() - - diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py deleted file mode 100644 index d25afcbb..00000000 --- a/extensions-builtin/Lora/network_norm.py +++ /dev/null @@ -1,28 +0,0 @@ -import network - - -class ModuleTypeNorm(network.ModuleType): - def create_module(self, net: network.Network, weights: network.NetworkWeights): 
- if all(x in weights.w for x in ["w_norm", "b_norm"]): - return NetworkModuleNorm(net, weights) - - return None - - -class NetworkModuleNorm(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) - - self.w_norm = weights.w.get("w_norm") - self.b_norm = weights.w.get("b_norm") - - def calc_updown(self, orig_weight): - output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device) - - if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device) - else: - ex_bias = None - - return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py deleted file mode 100644 index d1c46a4b..00000000 --- a/extensions-builtin/Lora/network_oft.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import network -from lyco_helpers import factorization -from einops import rearrange - - -class ModuleTypeOFT(network.ModuleType): - def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): - return NetworkModuleOFT(net, weights) - - return None - -# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py -# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py -class NetworkModuleOFT(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - - super().__init__(net, weights) - - self.lin_module = None - self.org_module: list[torch.Module] = [self.sd_module] - - self.scale = 1.0 - - # kohya-ss - if "oft_blocks" in weights.w.keys(): - self.is_kohya = True - self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) - self.alpha = weights.w["alpha"] # alpha is constraint - self.dim = self.oft_blocks.shape[0] # lora dim - # LyCORIS - elif "oft_diag" in weights.w.keys(): - self.is_kohya = False - self.oft_blocks = weights.w["oft_diag"] - # self.alpha is unused - self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) - - is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] - is_conv = type(self.sd_module) in [torch.nn.Conv2d] - is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported - - if is_linear: - self.out_dim = self.sd_module.out_features - elif is_conv: - self.out_dim = self.sd_module.out_channels - elif is_other_linear: - self.out_dim = self.sd_module.embed_dim - - if self.is_kohya: - self.constraint = self.alpha * self.out_dim - self.num_blocks = self.dim - self.block_size = self.out_dim // self.dim - else: - self.constraint = None - self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - - def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device) - eye = torch.eye(self.block_size, device=oft_blocks.device) - - if self.is_kohya: - block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device)) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) - - R = oft_blocks.to(orig_weight.device) - - # This errors out for 
MultiheadAttention, might need to be handled up-stream - merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - merged_weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R, - merged_weight - ) - merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - - updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) - output_shape = orig_weight.shape - return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 84c0169d..0ccd6a90 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -2,7 +2,7 @@ import os import re import lora_patches -import network +import functools import torch from typing import Union @@ -12,93 +12,13 @@ from ldm_patched.modules.utils import load_torch_file from ldm_patched.modules.sd import load_lora_for_models -lora_state_dict_cache = {} -lora_state_dict_cache_max_length = 5 - -module_types = [] - - -re_digits = re.compile(r"\d+") -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") -re_compiled = {} - -suffix_conversion = { - "attentions": {}, - "resnets": { - "conv1": "in_layers_2", - "conv2": "out_layers_3", - "norm1": "in_layers_0", - "norm2": "out_layers_0", - "time_emb_proj": "emb_layers_1", - "conv_shortcut": "skip_connection", - } -} +@functools.lru_cache(maxsize=5) +def load_lora_state_dict(filename): + return load_torch_file(filename, safe_load=True) def convert_diffusers_name_to_compvis(key, is_sd2): - def match(match_list, regex_text): - regex = re_compiled.get(regex_text) - if regex is None: - regex = re.compile(regex_text) - re_compiled[regex_text] = regex - - r = re.match(regex, key) - if not r: - return False - - match_list.clear() - match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) - return True - - m = [] - - if match(m, r"lora_unet_conv_in(.*)"): - return f'diffusion_model_input_blocks_0_0{m[0]}' - - if match(m, r"lora_unet_conv_out(.*)"): - return f'diffusion_model_out_2{m[0]}' - - if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"): - return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}" - - if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) - return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" - - if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) - return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): - return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" - - if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): - return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" - - if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): - if is_sd2: - if 'mlp_fc1' in m[1]: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" - elif 'mlp_fc2' in m[1]: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" - else: - return 
f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - - return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" - - if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): - if 'mlp_fc1' in m[1]: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" - elif 'mlp_fc2' in m[1]: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" - else: - return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - - return key + pass def assign_network_names_to_compvis_modules(sd_model): @@ -139,15 +59,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No current_sd.forge_objects.clip = current_sd.forge_objects.clip_original for filename, strength_model, strength_clip in compiled_lora_targets: - if filename in lora_state_dict_cache: - lora_sd = lora_state_dict_cache[filename] - else: - if len(lora_state_dict_cache) > lora_state_dict_cache_max_length: - lora_state_dict_cache = {} - - lora_sd = load_torch_file(filename, safe_load=True) - lora_state_dict_cache[filename] = lora_sd - + lora_sd = load_lora_state_dict(filename) current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models( current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip) diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 1518f7e5..c2ba5c73 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -3,7 +3,6 @@ import re import gradio as gr from fastapi import FastAPI -import network import networks import lora # noqa:F401 import lora_patches