From 9588721197bc3c61354811eca5aff6f470b0b2f8 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 7 Feb 2024 04:49:17 -0800 Subject: [PATCH 01/10] feat: support LyCORIS BOFT --- extensions-builtin/Lora/network_oft.py | 44 ++++++++++++++++++++------ 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index d1c46a4b2..8a37828cc 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,6 +1,6 @@ import torch import network -from lyco_helpers import factorization +from lyco_helpers import factorization, butterfly_factor from einops import rearrange @@ -36,6 +36,12 @@ class NetworkModuleOFT(network.NetworkModule): # self.alpha is unused self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + self.is_boft = False + if "boft" in weights.w.keys(): + self.is_boft = True + self.boft_b = weights.w["boft_b"] + self.boft_m = weights.w["boft_m"] + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported @@ -68,14 +74,34 @@ class NetworkModuleOFT(network.NetworkModule): R = oft_blocks.to(orig_weight.device) - # This errors out for MultiheadAttention, might need to be handled up-stream - merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - merged_weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R, - merged_weight - ) - merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if not self.is_boft: + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + else: + scale = 1.0 + m = self.boft_m.to(device=oft_blocks.device, dtype=oft_blocks.dtype) + b = self.boft_b.to(device=oft_blocks.device, dtype=oft_blocks.dtype) + r_b = b // 2 + inp = orig_weight + for i in range(m): + bi = R[i] # b_num, b_size, b_size + if i == 0: + # Apply multiplier/scale and rescale into first weight + bi = bi * scale + (1 - scale) * eye + #if self.rescaled: + # bi = bi * self.rescale + inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b) + inp = rearrange(inp, "(d b) ... -> d b ...", b=b) + inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp) + inp = rearrange(inp, "d b ... -> (d b) ...") + inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b) + merged_weight = inp updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape From a4668a16b6f8e98bc6e1553aa754735f9148770f Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 7 Feb 2024 04:51:22 -0800 Subject: [PATCH 02/10] fix: calculate butterfly factor --- extensions-builtin/Lora/network_oft.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 8a37828cc..0f20d701b 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -57,6 +57,9 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = self.alpha * self.out_dim self.num_blocks = self.dim self.block_size = self.out_dim // self.dim + elif self.is_boft: + self.constraint = None + self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim) else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) From 81c16c965e532c6d86a969284c320ff8fcb0451d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 7 Feb 2024 04:54:14 -0800 Subject: [PATCH 03/10] fix: add butterfly_factor fn --- extensions-builtin/Lora/lyco_helpers.py | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 1679a0ce6..3c4f5bad2 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -66,3 +66,29 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: n, m = m, n return m, n +# from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/boft.py +def butterfly_factor(dimension: int, factor: int = -1) -> tuple[int, int]: + """ + m = 2k + n = 2**p + m*n = dim + """ + + # Find the first solution and check if it is even doable + m = n = 0 + while m <= factor: + m += 2 + while dimension % m != 0 and m < dimension: + m += 2 + if m > factor: + break + if sum(int(i) for i in f"{dimension//m:b}") == 1: + n = dimension // m + + if n == 0: + raise ValueError( + f"It is impossible to decompose {dimension} with factor {factor} under BOFT constrains." + ) + + #log_butterfly_factorize(dimension, factor, (dimension // n, n)) + return dimension // n, n From 2f1073dc6edf2d1388f6aee4af91cb354099a463 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 7 Feb 2024 04:55:11 -0800 Subject: [PATCH 04/10] style: fix lint --- extensions-builtin/Lora/network_oft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 0f20d701b..dc6db56f1 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -96,7 +96,7 @@ class NetworkModuleOFT(network.NetworkModule): bi = R[i] # b_num, b_size, b_size if i == 0: # Apply multiplier/scale and rescale into first weight - bi = bi * scale + (1 - scale) * eye + bi = bi * scale + (1 - scale) * eye #if self.rescaled: # bi = bi * self.rescale inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b) From 325eaeb584f8565d49ce73553165088f794d3d12 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 8 Feb 2024 11:55:05 -0800 Subject: [PATCH 05/10] fix: get boft params from weight shape --- extensions-builtin/Lora/network_oft.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index dc6db56f1..fc7132651 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,6 +1,6 @@ import torch import network -from lyco_helpers import factorization, butterfly_factor +from lyco_helpers import factorization from einops import rearrange @@ -37,10 +37,8 @@ class NetworkModuleOFT(network.NetworkModule): self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) self.is_boft = False - if "boft" in weights.w.keys(): + if weights.w["oft_diag"].dim() == 4: self.is_boft = True - self.boft_b = weights.w["boft_b"] - self.boft_m = weights.w["boft_m"] is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] @@ -59,7 +57,11 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size = self.out_dim // self.dim elif self.is_boft: self.constraint = None - self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim) + self.boft_m = weights.w["oft_diag"].shape[0] + self.block_num = weights.w["oft_diag"].shape[1] + self.block_size = weights.w["oft_diag"].shape[2] + self.boft_b = self.block_size + #self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim) else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) @@ -88,8 +90,8 @@ class NetworkModuleOFT(network.NetworkModule): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') else: scale = 1.0 - m = self.boft_m.to(device=oft_blocks.device, dtype=oft_blocks.dtype) - b = self.boft_b.to(device=oft_blocks.device, dtype=oft_blocks.dtype) + m = self.boft_m + b = self.boft_b r_b = b // 2 inp = orig_weight for i in range(m): From 613b0d9548a859408433bff7a6dca7fd0f2eae7e Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:58:59 -0800 Subject: [PATCH 06/10] doc: add boft comment --- extensions-builtin/Lora/network_oft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fc7132651..d7b317029 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -29,13 +29,14 @@ class NetworkModuleOFT(network.NetworkModule): self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) self.alpha = weights.w["alpha"] # alpha is constraint self.dim = self.oft_blocks.shape[0] # lora dim - # LyCORIS + # LyCORIS OFT elif "oft_diag" in weights.w.keys(): self.is_kohya = False self.oft_blocks = weights.w["oft_diag"] # self.alpha is unused self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + # LyCORIS BOFT self.is_boft = False if weights.w["oft_diag"].dim() == 4: self.is_boft = True @@ -89,6 +90,7 @@ class NetworkModuleOFT(network.NetworkModule): ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') else: + # TODO: determine correct value for scale scale = 1.0 m = self.boft_m b = self.boft_b @@ -99,8 +101,6 @@ class NetworkModuleOFT(network.NetworkModule): if i == 0: # Apply multiplier/scale and rescale into first weight bi = bi * scale + (1 - scale) * eye - #if self.rescaled: - # bi = bi * self.rescale inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b) inp = rearrange(inp, "(d b) ... -> d b ...", b=b) inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp) From eb6f2df826087fdc62f6680364a0e16f666eef64 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 8 Feb 2024 22:00:15 -0800 Subject: [PATCH 07/10] Revert "fix: add butterfly_factor fn" This reverts commit 81c16c965e532c6d86a969284c320ff8fcb0451d. --- extensions-builtin/Lora/lyco_helpers.py | 26 ------------------------- 1 file changed, 26 deletions(-) diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 3c4f5bad2..1679a0ce6 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -66,29 +66,3 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: n, m = m, n return m, n -# from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/boft.py -def butterfly_factor(dimension: int, factor: int = -1) -> tuple[int, int]: - """ - m = 2k - n = 2**p - m*n = dim - """ - - # Find the first solution and check if it is even doable - m = n = 0 - while m <= factor: - m += 2 - while dimension % m != 0 and m < dimension: - m += 2 - if m > factor: - break - if sum(int(i) for i in f"{dimension//m:b}") == 1: - n = dimension // m - - if n == 0: - raise ValueError( - f"It is impossible to decompose {dimension} with factor {factor} under BOFT constrains." - ) - - #log_butterfly_factorize(dimension, factor, (dimension // n, n)) - return dimension // n, n From 90441294db16383bce6f341e8a1f67fe422172d4 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 12 Feb 2024 14:25:09 +0800 Subject: [PATCH 08/10] Add rescale mechanism LyCORIS will support save oft_blocks instead of oft_diag in the near future (for both OFT and BOFT) But this means we need to store the rescale if user enable it. --- extensions-builtin/Lora/network_oft.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index d7b317029..ed221d8fe 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -40,6 +40,7 @@ class NetworkModuleOFT(network.NetworkModule): self.is_boft = False if weights.w["oft_diag"].dim() == 4: self.is_boft = True + self.rescale = weight.w.get('rescale', None) is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] @@ -108,6 +109,10 @@ class NetworkModuleOFT(network.NetworkModule): inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b) merged_weight = inp + # Rescale mechanism + if self.rescale is not None: + merged_weight = self.rescale.to(merged_weight) * merged_weight + updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) From 5a8dd0c549c0221cd3ee1c53816aa52cf7b3b0ae Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 18 Feb 2024 14:58:41 +0800 Subject: [PATCH 09/10] Fix rescale --- extensions-builtin/Lora/network_oft.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ed221d8fe..f5e657b82 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -40,7 +40,9 @@ class NetworkModuleOFT(network.NetworkModule): self.is_boft = False if weights.w["oft_diag"].dim() == 4: self.is_boft = True - self.rescale = weight.w.get('rescale', None) + self.rescale = weights.w.get('rescale', None) + if self.rescale is not None: + self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1)) is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] From 4eb949625c8cc04ba579fc5486cc10acd541596b Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 19 Feb 2024 14:43:07 +0800 Subject: [PATCH 10/10] prevent undefined variable --- extensions-builtin/Lora/network_oft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index f5e657b82..d658ad109 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -22,6 +22,8 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] self.scale = 1.0 + self.is_kohya = False + self.is_boft = False # kohya-ss if "oft_blocks" in weights.w.keys(): @@ -31,13 +33,11 @@ class NetworkModuleOFT(network.NetworkModule): self.dim = self.oft_blocks.shape[0] # lora dim # LyCORIS OFT elif "oft_diag" in weights.w.keys(): - self.is_kohya = False self.oft_blocks = weights.w["oft_diag"] # self.alpha is unused self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) # LyCORIS BOFT - self.is_boft = False if weights.w["oft_diag"].dim() == 4: self.is_boft = True self.rescale = weights.w.get('rescale', None)