From b16819f8e789a5c89d82745b4a915278c86bff30 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 2 Mar 2025 06:57:50 -0700
Subject: [PATCH 1/6] Added LoKr support

---
 toolkit/config_modules.py |   9 ++
 toolkit/lora_special.py   |  23 +++-
 toolkit/models/lokr.py    | 264 ++++++++++++++++++++++----------------
 toolkit/network_mixins.py |  32 ++++-
 4 files changed, 215 insertions(+), 113 deletions(-)

diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 3fc7728c..2c118f1c 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -135,6 +135,15 @@ class NetworkConfig:
             self.conv = 4

         self.transformer_only = kwargs.get('transformer_only', True)
+
+        self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
+        if self.lokr_full_rank:
+            self.linear = 9999999999
+            self.linear_alpha = 9999999999
+            self.conv = 9999999999
+            self.conv_alpha = 9999999999
+        # -1 automatically finds the largest factor
+        self.lokr_factor = kwargs.get('lokr_factor', -1)

 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net']
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index b37ed098..ace0a2ec 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -231,12 +231,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
             module_class = DoRAModule
+        elif self.network_type.lower() == "lokr":
+            self.module_class = LokrModule
+            module_class = LokrModule
+
         self.network_config: NetworkConfig = kwargs.get("network_config", None)
         self.peft_format = peft_format
         # always do peft for flux only for now
         if self.is_flux or self.is_v3 or self.is_lumina2:
-            self.peft_format = True
+            # don't do peft format for lokr
+            if self.network_type.lower() != "lokr":
+                self.peft_format = True

         if self.peft_format:
             # no alpha for peft
@@ -373,6 +379,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                         self.conv_lora_dim is not None or conv_block_dims is not None):
                     skipped.append(lora_name)
                     continue
+
+                module_kwargs = {}
+
+                if self.network_type.lower() == "lokr":
+                    module_kwargs["factor"] = self.network_config.lokr_factor

                 lora = module_class(
                     lora_name,
@@ -386,10 +397,16 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                     network=self,
                     parent=module,
                     use_bias=use_bias,
+                    **module_kwargs
                 )
                 loras.append(lora)
-                lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)
-                                              ]
+                if self.network_type.lower() == "lokr":
+                    try:
+                        # lokr_w1 / lokr_w2 are plain nn.Parameters and may not exist
+                        # when the module decomposes them into a/b factor pairs
+                        lora_shape_dict[lora_name] = [list(lora.lokr_w1.shape), list(lora.lokr_w2.shape)]
+                    except AttributeError:
+                        pass
+                else:
+                    lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
             return loras, skipped

         text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
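The `lokr_full_rank` shortcut above works by pushing the requested rank so high that `LokrModule` (added below) always takes its full-matrix branches, leaving only the two Kronecker factors to train. A rough sketch of what that means for a Flux-sized `Linear(3072, 3072)` with `lokr_factor: 16` (a standalone illustration, not code from this patch; the shapes follow from the `factorization()` helper below):

```python
import torch

# factorization(3072, 16) -> (16, 192) for both the in and out dimensions, so:
out_l, out_k = 16, 192
in_m, in_n = 16, 192

w1 = torch.zeros(out_l, in_m)   # (16, 16)   -> 256 params ("scale" factor)
w2 = torch.zeros(out_k, in_n)   # (192, 192) -> 36,864 params ("weight" factor)

delta = torch.kron(w1, w2)      # Kronecker product rebuilds the full weight shape
assert delta.shape == (3072, 3072)
# ~37k trainable parameters versus ~9.4M in the base weight (~0.4%)
```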
diff --git a/toolkit/models/lokr.py b/toolkit/models/lokr.py
index b736406e..0240a97b 100644
--- a/toolkit/models/lokr.py
+++ b/toolkit/models/lokr.py
@@ -10,24 +10,23 @@ from toolkit.network_mixins import ToolkitModuleMixin
 from typing import TYPE_CHECKING, Union, List

+from optimum.quanto import QBytesTensor, QTensor
+
 if TYPE_CHECKING:
-
     from toolkit.lora_special import LoRASpecialNetwork

-# 4, build custom backward function
-#
-
-def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
+def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    '''
    return a tuple of two values of the input dimension decomposed by the number closest to factor;
    the second value is greater than or equal to the first.

    In LoRA with Kronecker Product, the first value is for the weight scale and
    the second value is for the weight itself.

    Because of the non-commutative property, A⊗B ≠ B⊗A, the meaning of the two
    matrices is slightly different.

    examples)
    factor
        -1              2               4               8               16
    127 -> 127, 1   127 -> 127, 1   127 -> 127, 1   127 -> 127, 1   127 -> 127, 1
@@ -38,7 +37,7 @@ def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
    512 -> 32, 16   512 -> 256, 2   512 -> 128, 4   512 -> 64, 8    512 -> 32, 16
    1024 -> 32, 32  1024 -> 512, 2  1024 -> 256, 4  1024 -> 128, 8  1024 -> 64, 16
    '''

    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
@@ -47,12 +46,12 @@ def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
@@ -62,7 +61,8 @@ def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:

 def make_weight_cp(t, wa, wb):
-    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l', t, wa, wb) # [c, d, k1, k2]
+    rebuild2 = torch.einsum('i j k l, i p, j r -> p r k l',
+                            t, wa, wb)  # [c, d, k1, k2]
     return rebuild2

@@ -71,31 +71,25 @@ def make_kron(w1, w2, scale):
         w1 = w1.unsqueeze(2).unsqueeze(2)
     w2 = w2.contiguous()
     rebuild = torch.kron(w1, w2)
-
     return rebuild*scale

 class LokrModule(ToolkitModuleMixin, nn.Module):
-    """
-    modifed from kohya-ss/sd-scripts/networks/lora:LoRAModule
-    and from KohakuBlueleaf/LyCORIS/lycoris:loha:LoHaModule
-    and from KohakuBlueleaf/LyCORIS/lycoris:locon:LoconModule
-    """
-
     def __init__(
-        self,
-        lora_name,
-        org_module: nn.Module,
-        multiplier=1.0,
-        lora_dim=4,
-        alpha=1,
-        dropout=0.,
-        rank_dropout=0.,
+        self,
+        lora_name,
+        org_module: nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=0.,
+        rank_dropout=0.,
         module_dropout=0.,
         use_cp=False,
-        decompose_both = False,
+        decompose_both=False,
         network: 'LoRASpecialNetwork' = None,
-        factor:int=-1,  # factorization factor
+        factor: int = -1,  # factorization factor
        **kwargs,
    ):
        """ if alpha == 0 or None, alpha is rank (no scaling).
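        For example, a Linear layer with in_features = out_features = 3072 and factor = 16
        factorizes both dimensions as (16, 192): lokr_w1 is then (16, 16), and lokr_w2 is
        either the full (192, 192) matrix (when lora_dim >= 96, as with lokr_full_rank) or
        the product of lokr_w2_a (192, lora_dim) and lokr_w2_b (lora_dim, 192); the
        Kronecker product of the two factors rebuilds the (3072, 3072) weight delta.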
""" @@ -113,32 +107,42 @@ class LokrModule(ToolkitModuleMixin, nn.Module): in_dim = org_module.in_channels k_size = org_module.kernel_size out_dim = org_module.out_channels - + in_m, in_n = factorization(in_dim, factor) out_l, out_k = factorization(out_dim, factor) - shape = ((out_l, out_k), (in_m, in_n), *k_size) # ((a, b), (c, d), *k_size) - - self.cp = use_cp and k_size!=(1, 1) + # ((a, b), (c, d), *k_size) + shape = ((out_l, out_k), (in_m, in_n), *k_size) + + self.cp = use_cp and k_size != (1, 1) if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2: - self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim)) - self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0])) + self.lokr_w1_a = nn.Parameter( + torch.empty(shape[0][0], lora_dim)) + self.lokr_w1_b = nn.Parameter( + torch.empty(lora_dim, shape[1][0])) else: self.use_w1 = True - self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode - + self.lokr_w1 = nn.Parameter(torch.empty( + shape[0][0], shape[1][0])) # a*c, 1-mode + if lora_dim >= max(shape[0][1], shape[1][1])/2: self.use_w2 = True - self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *k_size)) + self.lokr_w2 = nn.Parameter(torch.empty( + shape[0][1], shape[1][1], *k_size)) elif self.cp: - self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, shape[2], shape[3])) - self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, shape[0][1])) # b, 1-mode - self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1])) # d, 2-mode - else: # Conv2d not cp + self.lokr_t2 = nn.Parameter(torch.empty( + lora_dim, lora_dim, shape[2], shape[3])) + self.lokr_w2_a = nn.Parameter( + torch.empty(lora_dim, shape[0][1])) # b, 1-mode + self.lokr_w2_b = nn.Parameter( + torch.empty(lora_dim, shape[1][1])) # d, 2-mode + else: # Conv2d not cp # bigger part. weight and LoRA. [b, dim] x [dim, d*k1*k2] - self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim)) - self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]*shape[2]*shape[3])) + self.lokr_w2_a = nn.Parameter( + torch.empty(shape[0][1], lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty( + lora_dim, shape[1][1]*shape[2]*shape[3])) # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d*k1*k2)) = (a, b)⊗(c, d*k1*k2) = (ac, bd*k1*k2) - + self.op = F.conv2d self.extra_args = { "stride": org_module.stride, @@ -147,48 +151,55 @@ class LokrModule(ToolkitModuleMixin, nn.Module): "groups": org_module.groups } - else: # Linear + else: # Linear in_dim = org_module.in_features out_dim = org_module.out_features - + in_m, in_n = factorization(in_dim, factor) out_l, out_k = factorization(out_dim, factor) - shape = ((out_l, out_k), (in_m, in_n)) # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d - + # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d + shape = ((out_l, out_k), (in_m, in_n)) + # smaller part. weight scale if decompose_both and lora_dim < max(shape[0][0], shape[1][0])/2: - self.lokr_w1_a = nn.Parameter(torch.empty(shape[0][0], lora_dim)) - self.lokr_w1_b = nn.Parameter(torch.empty(lora_dim, shape[1][0])) + self.lokr_w1_a = nn.Parameter( + torch.empty(shape[0][0], lora_dim)) + self.lokr_w1_b = nn.Parameter( + torch.empty(lora_dim, shape[1][0])) else: self.use_w1 = True - self.lokr_w1 = nn.Parameter(torch.empty(shape[0][0], shape[1][0])) # a*c, 1-mode + self.lokr_w1 = nn.Parameter(torch.empty( + shape[0][0], shape[1][0])) # a*c, 1-mode if lora_dim < max(shape[0][1], shape[1][1])/2: # bigger part. weight and LoRA. 
                # [b, dim] x [dim, d]
-                self.lokr_w2_a = nn.Parameter(torch.empty(shape[0][1], lora_dim))
-                self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, shape[1][1]))
+                self.lokr_w2_a = nn.Parameter(
+                    torch.empty(shape[0][1], lora_dim))
+                self.lokr_w2_b = nn.Parameter(
+                    torch.empty(lora_dim, shape[1][1]))
                 # w1 ⊗ (w2_a x w2_b) = (a, b)⊗((c, dim)x(dim, d)) = (a, b)⊗(c, d) = (ac, bd)
             else:
                 self.use_w2 = True
-                self.lokr_w2 = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
+                self.lokr_w2 = nn.Parameter(
+                    torch.empty(shape[0][1], shape[1][1]))

             self.op = F.linear
             self.extra_args = {}

         self.dropout = dropout
         if dropout:
-            print("[WARN]LoHa/LoKr haven't implemented normal dropout yet.")
+            print("[WARN] LoKr has not implemented normal dropout yet.")
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout

         if isinstance(alpha, torch.Tensor):
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
         alpha = lora_dim if alpha is None or alpha == 0 else alpha
         if self.use_w2 and self.use_w1:
-            #use scale = 1
+            # use scale = 1
             alpha = lora_dim

         self.scale = alpha / self.lora_dim
-        self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
+        self.register_buffer('alpha', torch.tensor(alpha))  # treat as a constant

         if self.use_w2:
             torch.nn.init.constant_(self.lokr_w2, 0)
@@ -197,7 +208,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
             torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
             torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
             torch.nn.init.constant_(self.lokr_w2_b, 0)

         if self.use_w1:
             torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
         else:
@@ -208,8 +219,8 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
         self.org_module = [org_module]
         weight = make_kron(
             self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
-            (self.lokr_w2 if self.use_w2
-             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
+            (self.lokr_w2 if self.use_w2
+             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
              else self.lokr_w2_a@self.lokr_w2_b),
             torch.tensor(self.multiplier * self.scale)
         )
@@ -219,12 +230,12 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
     def apply_to(self):
         self.org_forward = self.org_module[0].forward
         self.org_module[0].forward = self.forward

-    def get_weight(self, orig_weight = None):
+    def get_weight(self, orig_weight=None):
         weight = make_kron(
             self.lokr_w1 if self.use_w1 else self.lokr_w1_a@self.lokr_w1_b,
-            (self.lokr_w2 if self.use_w2
-             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
+            (self.lokr_w2 if self.use_w2
+             else make_weight_cp(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) if self.cp
              else self.lokr_w2_a@self.lokr_w2_b),
             torch.tensor(self.scale)
         )
@@ -232,51 +243,88 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
             weight = weight.reshape(orig_weight.shape)
         if self.training and self.rank_dropout:
             drop = torch.rand(weight.size(0)) < self.rank_dropout
-            weight *= drop.view(-1, [1]*len(weight.shape[1:])).to(weight.device)
+            # the mask needs unpacked singleton dims to broadcast correctly
+            weight *= drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
         return weight

     @torch.no_grad()
-    def apply_max_norm(self, max_norm, device=None):
-        orig_norm = self.get_weight().norm()
-        norm = torch.clamp(orig_norm, max_norm/2)
-        desired = torch.clamp(norm, max=max_norm)
-        ratio = desired.cpu()/norm.cpu()
-
-        scaled = ratio.item() != 1.0
-        if scaled:
-            modules = (4 - self.use_w1 - self.use_w2 + (not self.use_w2 and self.cp))
-            if self.use_w1:
-                self.lokr_w1 *= ratio**(1/modules)
-            else:
-                self.lokr_w1_a *= ratio**(1/modules)
-                self.lokr_w1_b *= ratio**(1/modules)
-
-            if self.use_w2:
-                self.lokr_w2 *= ratio**(1/modules)
-            else:
-                if self.cp:
-                    self.lokr_t2 *= ratio**(1/modules)
-                self.lokr_w2_a *= ratio**(1/modules)
-                self.lokr_w2_b *= ratio**(1/modules)
-
-        return scaled, orig_norm*ratio
+    def merge_in(self, merge_weight=1.0):
+        if not self.can_merge_in:
+            return

-    def forward(self, x):
-        if self.module_dropout and self.training:
-            if torch.rand(1) < self.module_dropout:
-                return self.op(
-                    x,
-                    self.org_module[0].weight.data,
-                    None if self.org_module[0].bias is None else self.org_module[0].bias.data
-                )
-        weight = (
-            self.org_module[0].weight.data
-            + self.get_weight(self.org_module[0].weight.data) * self.multiplier
-        )
-        bias = None if self.org_module[0].bias is None else self.org_module[0].bias.data
-        return self.op(
-            x,
+        # extract weight from org_module
+        org_sd = self.org_module[0].state_dict()
+        # todo find a way to merge in weights when doing quantized model
+        if 'weight._data' in org_sd:
+            # quantized weight; merging is not supported yet
+            return
+
+        weight_key = "weight"
+
+        orig_dtype = org_sd[weight_key].dtype
+        weight = org_sd[weight_key].float()
+
+        scale = self.scale
+        # handle the trainable scalar method locon uses
+        if hasattr(self, 'scalar'):
+            scale = scale * self.scalar
+
+        lokr_weight = self.get_weight(weight)
+
+        merged_weight = (
+            weight
+            + (lokr_weight * merge_weight).to(weight.device, dtype=weight.dtype)
+        )
+
+        # set weight to org_module
+        org_sd[weight_key] = merged_weight.to(orig_dtype)
+        self.org_module[0].load_state_dict(org_sd)
+
+    def get_orig_weight(self):
+        weight = self.org_module[0].weight
+        if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
+            return weight.dequantize().data.detach()
+        else:
+            return weight.data.detach()
+
+    def get_orig_bias(self):
+        if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
+            if isinstance(self.org_module[0].bias, QTensor) or isinstance(self.org_module[0].bias, QBytesTensor):
+                return self.org_module[0].bias.dequantize().data.detach()
+            else:
+                return self.org_module[0].bias.data.detach()
+        return None
+
+    def _call_forward(self, x):
+        if isinstance(x, QTensor) or isinstance(x, QBytesTensor):
+            x = x.dequantize()
+
+        orig_dtype = x.dtype
+
+        orig_weight = self.get_orig_weight()
+        lokr_weight = self.get_weight(orig_weight).to(dtype=orig_weight.dtype)
+        multiplier = self.network_ref().torch_multiplier
+
+        if x.dtype != orig_weight.dtype:
+            x = x.to(dtype=orig_weight.dtype)
+
+        # we do not currently support split batch multipliers for lokr, so just take the mean
+        multiplier = torch.mean(multiplier)
+
+        weight = (
+            orig_weight
+            + lokr_weight * multiplier
+        )
+        bias = self.get_orig_bias()
+        if bias is not None:
+            bias = bias.to(weight.device, dtype=weight.dtype)
+        output = self.op(
+            x,
             weight.view(self.shape),
             bias,
             **self.extra_args
-        )
\ No newline at end of file
+        )
+        return output.to(orig_dtype)
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 37f7987e..c9a73000 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -272,6 +272,9 @@ class ToolkitModuleMixin:
         # if self.__class__.__name__ == "DoRAModule":
         #     # return dora forward
         #     return self.dora_forward(x, *args, **kwargs)
+
+        if self.__class__.__name__ == "LokrModule":
+            return self._call_forward(x)

         org_forwarded = self.org_forward(x, *args, **kwargs)

@@ -540,6 +543,17 @@ class ToolkitNetworkMixin:
                 new_save_dict[new_key] = value

             save_dict = new_save_dict
+
+        if self.network_type.lower() == "lokr":
+            new_save_dict = {}
+            for key, value in save_dict.items():
+                # lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1 to lycoris_transformer_blocks_7_attn_to_v.lokr_w1
+                new_key = key.replace('lora_transformer_', 'lycoris_')
+                new_save_dict[new_key] = value
+
+            save_dict = new_save_dict

         if metadata is None:
             metadata = OrderedDict()
@@ -585,6 +599,10 @@ class ToolkitNetworkMixin:
                 load_key = load_key.replace('.', '$$')
                 load_key = load_key.replace('$$lora_down$$', '.lora_down.')
                 load_key = load_key.replace('$$lora_up$$', '.lora_up.')
+
+            if self.network_type.lower() == "lokr":
+                # lycoris_transformer_blocks_7_attn_to_v.lokr_w1 back to lora_transformer_transformer_blocks_7_attn_to_v.lokr_w1
+                load_key = load_key.replace('lycoris_', 'lora_transformer_')

             load_sd[load_key] = value

@@ -617,8 +635,18 @@ class ToolkitNetworkMixin:
         multiplier = self._multiplier
         # get first module
         first_module = self.get_all_modules()[0]
-        device = first_module.lora_down.weight.device
-        dtype = first_module.lora_down.weight.dtype
+
+        if hasattr(first_module, 'lora_down'):
+            device = first_module.lora_down.weight.device
+            dtype = first_module.lora_down.weight.dtype
+        elif hasattr(first_module, 'lokr_w1'):
+            device = first_module.lokr_w1.device
+            dtype = first_module.lokr_w1.dtype
+        elif hasattr(first_module, 'lokr_w1_a'):
+            device = first_module.lokr_w1_a.device
+            dtype = first_module.lokr_w1_a.dtype
+        else:
+            raise ValueError("Unknown module type")
         with torch.no_grad():
             tensor_multiplier = None
             if isinstance(multiplier, int) or isinstance(multiplier, float):

From 7ae31c9ae907241d2c7879fc6627bcee45a97376 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 2 Mar 2025 08:49:01 -0700
Subject: [PATCH 2/6] Added LoKr to the ui

---
 toolkit/config_modules.py        |  2 +-
 ui/src/app/jobs/new/jobConfig.ts |  2 ++
 ui/src/app/jobs/new/page.tsx     | 35 +++++++++++++++++++++++++++-----
 ui/src/types.ts                  |  4 +++-
 4 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 2c118f1c..07b3e2b7 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -137,7 +137,7 @@ class NetworkConfig:
         self.transformer_only = kwargs.get('transformer_only', True)

         self.lokr_full_rank = kwargs.get('lokr_full_rank', False)
-        if self.lokr_full_rank:
+        if self.lokr_full_rank and self.type.lower() == 'lokr':
             self.linear = 9999999999
             self.linear_alpha = 9999999999
             self.conv = 9999999999
diff --git a/ui/src/app/jobs/new/jobConfig.ts b/ui/src/app/jobs/new/jobConfig.ts
index d87d12c4..8098171d 100644
---
a/ui/src/app/jobs/new/jobConfig.ts +++ b/ui/src/app/jobs/new/jobConfig.ts @@ -29,6 +29,8 @@ export const defaultJobConfig: JobConfig = { type: 'lora', linear: 16, linear_alpha: 16, + lokr_full_rank: true, + lokr_factor: -1 }, save: { dtype: 'bf16', diff --git a/ui/src/app/jobs/new/page.tsx b/ui/src/app/jobs/new/page.tsx index e3aec77f..4652cf20 100644 --- a/ui/src/app/jobs/new/page.tsx +++ b/ui/src/app/jobs/new/page.tsx @@ -227,8 +227,31 @@ export default function TrainingForm() { - {jobConfig.config.process[0].network?.type && ( - + + setJobConfig(value, 'config.process[0].network.type')} + options={[ + { value: 'lora', label: 'LoRA' }, + { value: 'lokr', label: 'LoKr' }, + ]} + /> + {jobConfig.config.process[0].network?.type == 'lokr' && ( + setJobConfig(parseInt(value), 'config.process[0].network.lokr_factor')} + options={[ + { value: '-1', label: 'Auto' }, + { value: '4', label: '4' }, + { value: '8', label: '8' }, + { value: '16', label: '16' }, + { value: '32', label: '32' }, + ]} + /> + )} + {jobConfig.config.process[0].network?.type == 'lora' && ( - - )} + )} + setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier')} + onChange={value => + setJobConfig(value, 'config.process[0].train.diff_output_preservation_multiplier') + } placeholder="eg. 1.0" min={0} /> diff --git a/ui/src/types.ts b/ui/src/types.ts index 16ecb220..0aa7b67b 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -50,9 +50,11 @@ export interface GPUApiResponse { */ export interface NetworkConfig { - type: 'lora'; + type: string; linear: number; linear_alpha: number; + lokr_full_rank: boolean; + lokr_factor: number; } export interface SaveConfig { From b001d77efb4b87a3f03ae05fc94fe71cffc3e78e Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 2 Mar 2025 08:55:56 -0700 Subject: [PATCH 3/6] Added LoKr instructions to the readme --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index f974dafc..5f0c8cb1 100644 --- a/README.md +++ b/README.md @@ -311,3 +311,17 @@ You can also exclude layers by their names by using `ignore_if_contains` network `ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both, if will be ignored. + +## LoKr Training + +To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). To train a LoKr model, you can adjust the network type in the config file like so: + +```yaml + network: + type: "lokr" + lokr_full_rank: true + lokr_factor: 8 +``` + +Everything else should work the same including layer targeting. + From 3c8c84f15604d56239a68bf1fa11fd774e168ee1 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 2 Mar 2025 10:25:27 -0700 Subject: [PATCH 4/6] Added supporters to readme and a script to update it --- README.md | 24 ++- scripts/update_sponsors.py | 309 +++++++++++++++++++++++++++++++++++++ todo_multigpu.md | 3 - 3 files changed, 328 insertions(+), 8 deletions(-) create mode 100644 scripts/update_sponsors.py delete mode 100644 todo_multigpu.md diff --git a/README.md b/README.md index 5f0c8cb1..27d4404c 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,27 @@ # AI Toolkit by Ostris -## Support my work +## Support My Work - -Patreon - ostris - +If you enjoy my work, or use it for commercial purposes, please consider sponsoring me so I can continue to maintain it. Every bit helps! 
+ +[Become a sponsor on GitHub](https://github.com/orgs/ostris) or [support me on Patreon](https://www.patreon.com/ostris). + +Thank you to all my current supporters! + +_Last updated: 2025-03-02_ + +## GitHub Sponsors + +Replicate + +## Patreon Supporters + +Al H clement Delangue Cosmosis David Garrido Doron Adler Eli Slugworth EmmanuelMr18 HestoySeghuro . Jack Blakely Jason Jean-Tristan Marin Jodh Singh John Dopamine Joseph Rocca Kristjan Retter Maciej Popławski Michael Levine Miguel Lara Misch Strotz Mohamed Oumoumad Noctre Patron Paul Fidika Prasanth Veerina Razvan Grigore Steve Hanff Steve Informal The Local Lab Vladimir Sotnikov Zoltán-Csaba Nyiró + + +--- -I work on open source full time, which means I 100% rely on donations to make a living. If you find this project helpful, or use it in for commercial purposes, please consider donating to support my work on [Patreon](https://www.patreon.com/ostris) or [Github Sponsors](https://github.com/sponsors/ostris). ## Installation diff --git a/scripts/update_sponsors.py b/scripts/update_sponsors.py new file mode 100644 index 00000000..1e3bbf63 --- /dev/null +++ b/scripts/update_sponsors.py @@ -0,0 +1,309 @@ +import os +import requests +import json +from datetime import datetime +from dotenv import load_dotenv + +# Load environment variables from .env file +env_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env") +load_dotenv(dotenv_path=env_path) + +# API credentials +PATREON_TOKEN = os.getenv("PATREON_ACCESS_TOKEN") +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +GITHUB_USERNAME = os.getenv("GITHUB_USERNAME") +GITHUB_ORG = os.getenv("GITHUB_ORG") # Organization name (optional) + +# Output file +README_PATH = "SUPPORTERS.md" + +def fetch_patreon_supporters(): + """Fetch current Patreon supporters""" + print("Fetching Patreon supporters...") + + headers = { + "Authorization": f"Bearer {PATREON_TOKEN}", + "Content-Type": "application/json" + } + + url = "https://www.patreon.com/api/oauth2/v2/campaigns" + + try: + # First get the campaign ID + campaign_response = requests.get(url, headers=headers) + campaign_response.raise_for_status() + campaign_data = campaign_response.json() + + if not campaign_data.get('data'): + print("No campaigns found for this Patreon account") + return [] + + campaign_id = campaign_data['data'][0]['id'] + + # Now get the supporters for this campaign + members_url = f"https://www.patreon.com/api/oauth2/v2/campaigns/{campaign_id}/members" + params = { + "include": "user", + "fields[member]": "full_name,is_follower,patron_status", # Removed profile_url + "fields[user]": "image_url" + } + + supporters = [] + while members_url: + members_response = requests.get(members_url, headers=headers, params=params) + members_response.raise_for_status() + members_data = members_response.json() + + # Process the response to extract active patrons + for member in members_data.get('data', []): + attributes = member.get('attributes', {}) + + # Only include active patrons + if attributes.get('patron_status') == 'active_patron': + name = attributes.get('full_name', 'Anonymous Supporter') + + # Get user data which contains the profile image + user_id = member.get('relationships', {}).get('user', {}).get('data', {}).get('id') + profile_image = None + profile_url = None # Removed profile_url since it's not supported + + if user_id: + for included in members_data.get('included', []): + if included.get('id') == user_id and included.get('type') == 'user': + profile_image = included.get('attributes', {}).get('image_url') + break + + 
supporters.append({ + 'name': name, + 'profile_image': profile_image, + 'profile_url': profile_url, # This will be None + 'platform': 'Patreon', + 'amount': 0 # Placeholder, as Patreon API doesn't provide this in the current response + }) + + # Handle pagination + members_url = members_data.get('links', {}).get('next') + + print(f"Found {len(supporters)} active Patreon supporters") + return supporters + + except requests.exceptions.RequestException as e: + print(f"Error fetching Patreon data: {e}") + print(f"Response content: {e.response.content if hasattr(e, 'response') else 'No response content'}") + return [] + +def fetch_github_sponsors(): + """Fetch current GitHub sponsors for a user or organization""" + print("Fetching GitHub sponsors...") + + headers = { + "Authorization": f"Bearer {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json" + } + + # Determine if we're fetching for a user or an organization + entity_type = "organization" if GITHUB_ORG else "user" + entity_name = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME + + if not entity_name: + print("Error: Neither GITHUB_USERNAME nor GITHUB_ORG is set") + return [] + + # Different GraphQL query structure based on entity type + if entity_type == "user": + query = """ + query { + user(login: "%s") { + sponsorshipsAsMaintainer(first: 100) { + nodes { + sponsorEntity { + ... on User { + login + name + avatarUrl + url + } + ... on Organization { + login + name + avatarUrl + url + } + } + tier { + monthlyPriceInDollars + } + isOneTimePayment + isActive + } + } + } + } + """ % entity_name + else: # organization + query = """ + query { + organization(login: "%s") { + sponsorshipsAsMaintainer(first: 100) { + nodes { + sponsorEntity { + ... on User { + login + name + avatarUrl + url + } + ... 
on Organization { + login + name + avatarUrl + url + } + } + tier { + monthlyPriceInDollars + } + isOneTimePayment + isActive + } + } + } + } + """ % entity_name + + try: + response = requests.post( + "https://api.github.com/graphql", + headers=headers, + json={"query": query} + ) + response.raise_for_status() + data = response.json() + + # Process the response - the path to the data differs based on entity type + if entity_type == "user": + sponsors_data = data.get('data', {}).get('user', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', []) + else: + sponsors_data = data.get('data', {}).get('organization', {}).get('sponsorshipsAsMaintainer', {}).get('nodes', []) + + sponsors = [] + for sponsor in sponsors_data: + # Only include active sponsors + if sponsor.get('isActive'): + entity = sponsor.get('sponsorEntity', {}) + name = entity.get('name') or entity.get('login', 'Anonymous Sponsor') + profile_image = entity.get('avatarUrl') + profile_url = entity.get('url') + amount = sponsor.get('tier', {}).get('monthlyPriceInDollars', 0) + + sponsors.append({ + 'name': name, + 'profile_image': profile_image, + 'profile_url': profile_url, + 'platform': 'GitHub Sponsors', + 'amount': amount + }) + + print(f"Found {len(sponsors)} active GitHub sponsors for {entity_type} '{entity_name}'") + return sponsors + + except requests.exceptions.RequestException as e: + print(f"Error fetching GitHub sponsors data: {e}") + return [] + +def generate_readme(supporters): + """Generate a README.md file with supporter information""" + print(f"Generating {README_PATH}...") + + # Sort supporters by amount (descending) and then by name + supporters.sort(key=lambda x: (-x['amount'], x['name'].lower())) + + # Determine the proper footer links based on what's configured + github_entity = GITHUB_ORG if GITHUB_ORG else GITHUB_USERNAME + github_entity_type = "orgs" if GITHUB_ORG else "sponsors" + github_sponsor_url = f"https://github.com/{github_entity_type}/{github_entity}" + + with open(README_PATH, "w", encoding="utf-8") as f: + f.write("## Support My Work\n\n") + f.write("If you enjoy my work, or use it for commercial purposes, please consider sponsoring me so I can continue to maintain it. Every bit helps! 
\n\n") + # Create appropriate call-to-action based on what's configured + cta_parts = [] + if github_entity: + cta_parts.append(f"[Become a sponsor on GitHub]({github_sponsor_url})") + if PATREON_TOKEN: + cta_parts.append("[support me on Patreon](https://www.patreon.com/ostris)") + + if cta_parts: + if GITHUB_ORG: + f.write(f"{' or '.join(cta_parts)}.\n\n") + f.write("Thank you to all my current supporters!\n\n") + + f.write(f"_Last updated: {datetime.now().strftime('%Y-%m-%d')}_\n\n") + + # Write GitHub Sponsors section + github_sponsors = [s for s in supporters if s['platform'] == 'GitHub Sponsors'] + if github_sponsors: + f.write("### GitHub Sponsors\n\n") + for sponsor in github_sponsors: + if sponsor['profile_image']: + f.write(f"\"{sponsor['name']}\" ") + else: + f.write(f"[{sponsor['name']}]({sponsor['profile_url']}) ") + f.write("\n\n") + + # Write Patreon section + patreon_supporters = [s for s in supporters if s['platform'] == 'Patreon'] + if patreon_supporters: + f.write("### Patreon Supporters\n\n") + for supporter in patreon_supporters: + if supporter['profile_image']: + f.write(f"\"{supporter['name']}\" ") + else: + f.write(f"[{supporter['name']}]({supporter['profile_url']}) ") + f.write("\n\n") + + f.write("\n---\n\n") + + + print(f"Successfully generated {README_PATH} with {len(supporters)} supporters!") + +def main(): + """Main function""" + print("Starting supporter data collection...") + + # Check if required environment variables are set + missing_vars = [] + if not GITHUB_TOKEN: + missing_vars.append("GITHUB_TOKEN") + + # Either username or org is required for GitHub + if not GITHUB_USERNAME and not GITHUB_ORG: + missing_vars.append("GITHUB_USERNAME or GITHUB_ORG") + + # Patreon token is optional but warn if missing + patreon_enabled = bool(PATREON_TOKEN) + + if missing_vars: + print(f"Error: Missing required environment variables: {', '.join(missing_vars)}") + print("Please add them to your .env file") + return + + if not patreon_enabled: + print("Warning: PATREON_ACCESS_TOKEN not set. Will only fetch GitHub sponsors.") + + # Fetch data from both platforms + patreon_supporters = fetch_patreon_supporters() if PATREON_TOKEN else [] + github_sponsors = fetch_github_sponsors() + + # Combine supporters from both platforms + all_supporters = patreon_supporters + github_sponsors + + if not all_supporters: + print("No supporters found on either platform") + return + + # Generate README + generate_readme(all_supporters) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/todo_multigpu.md b/todo_multigpu.md deleted file mode 100644 index 02d5abda..00000000 --- a/todo_multigpu.md +++ /dev/null @@ -1,3 +0,0 @@ -- only do ema on main device? 
From 1f3f45a48d5e1bfee11a9ad1d060598a4c6906f6 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 3 Mar 2025 08:22:15 -0700
Subject: [PATCH 5/6] Bugfixes

---
 toolkit/lora_special.py   | 5 +++--
 toolkit/models/lokr.py    | 1 +
 toolkit/network_mixins.py | 5 ++++-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index ace0a2ec..1b308cd4 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -344,8 +344,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):

             if (is_linear or is_conv2d) and not skip:
-                if self.only_if_contains is not None and not any([word in clean_name for word in self.only_if_contains]):
-                    continue
+                if self.only_if_contains is not None:
+                    if not any([word in clean_name for word in self.only_if_contains]) and not any([word in lora_name for word in self.only_if_contains]):
+                        continue

                 dim = None
                 alpha = None
diff --git a/toolkit/models/lokr.py b/toolkit/models/lokr.py
index 0240a97b..3d7e6ca6 100644
--- a/toolkit/models/lokr.py
+++ b/toolkit/models/lokr.py
@@ -101,6 +101,7 @@ class LokrModule(ToolkitModuleMixin, nn.Module):
         self.cp = False
         self.use_w1 = False
         self.use_w2 = False
+        self.can_merge_in = True

         self.shape = org_module.weight.shape
         if org_module.__class__.__name__ == 'Conv2d':
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index c9a73000..d2d1a500 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -634,7 +634,10 @@ class ToolkitNetworkMixin:
         # without having to set it in every single module every time it changes
         multiplier = self._multiplier
         # get first module
-        first_module = self.get_all_modules()[0]
+        try:
+            first_module = self.get_all_modules()[0]
+        except IndexError:
+            raise ValueError("No LoRA modules found in this network. Check your config and try again")
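The try/except added above guards the first-module probe from patch 1, which has to cover three module layouts: LoRA-style modules wrap their factors in child modules with a `.weight`, while LoKr stores bare `nn.Parameter`s, either whole (`lokr_w1`) or decomposed (`lokr_w1_a`). The same logic as a standalone helper (a sketch; `module` is assumed to be one of this repo's adapter modules):

```python
import torch

def module_device_dtype(module) -> tuple[torch.device, torch.dtype]:
    # lora_down is a submodule carrying a .weight; lokr_w1 / lokr_w1_a are plain Parameters
    for attr in ("lora_down", "lokr_w1", "lokr_w1_a"):
        candidate = getattr(module, attr, None)
        if candidate is None:
            continue
        tensor = getattr(candidate, "weight", candidate)
        return tensor.device, tensor.dtype
    raise ValueError("Unknown module type")
```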
From c5e0c2bbe22889962f72005514234b6cd63df531 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 3 Mar 2025 16:27:19 -0700
Subject: [PATCH 6/6] Fixes to allow for redux assisted training

---
 toolkit/stable_diffusion_model.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 00b38574..7a2dcdff 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1343,7 +1343,8 @@ class StableDiffusion:
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                     self.adapter(conditional_clip_embeds)

-            if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+            if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \
+                    and gen_config.adapter_image_path is not None:
                 # handle conditioning the prompts
                 gen_config.prompt = self.adapter.condition_prompt(
                     gen_config.prompt,
@@ -1397,7 +1398,7 @@ class StableDiffusion:
                 conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
                 unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)

-            if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
+            if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
                 conditional_embeds = self.adapter.condition_encoded_embeds(
                     tensors_0_1=validation_image,
                     prompt_embeds=conditional_embeds,
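Taken together, the LoKr pieces in this series reduce to one identity: the trained delta is `w1 ⊗ w2` (optionally low-rank factored), scaled, and added onto the base weight exactly like a LoRA delta in `merge_in` and `_call_forward`. A standalone check of that claim (small shapes chosen for illustration; `factorization(8) == (2, 4)` and `factorization(12) == (3, 4)` per the helper in `toolkit/models/lokr.py`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
base = torch.nn.Linear(12, 8, bias=False)
x = torch.randn(5, 12)

w1 = torch.randn(2, 3)   # (out_l, in_m) "scale" factor
w2 = torch.randn(4, 4)   # (out_k, in_n) "weight" factor
scale = 0.5

delta = torch.kron(w1, w2) * scale          # (8, 12), same shape as base.weight
merged = F.linear(x, base.weight + delta)   # forward pass with the merged weight

# identical to the base output plus the LoKr contribution
assert delta.shape == base.weight.shape
assert torch.allclose(merged, base(x) + F.linear(x, delta), atol=1e-5)
```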