From 64a54418323192d6ba8288e93a8e50c17467c929 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 4 Sep 2023 14:05:10 -0600 Subject: [PATCH] Fully tested and now supporting locon on sdxl. If you have the ram --- jobs/process/BaseSDTrainProcess.py | 8 + toolkit/lora_special.py | 19 +- toolkit/lycoris_special.py | 313 ++++++++++++++++++++++++++--- toolkit/network_mixins.py | 13 +- toolkit/stable_diffusion_model.py | 6 +- 5 files changed, 320 insertions(+), 39 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 16123fb6..e532ef31 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -5,6 +5,7 @@ from collections import OrderedDict import os from typing import Union +from lycoris.config import PRESET from torch.utils.data import DataLoader from toolkit.data_loader import get_dataloader_from_datasets @@ -468,12 +469,19 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.network_config is not None: # TODO should we completely switch to LycorisSpecialNetwork? 
+ is_lycoris = False # default to LoCON if there are any conv layers or if it is named NetworkClass = LoRASpecialNetwork if self.network_config.conv is not None and self.network_config.conv > 0: NetworkClass = LycorisSpecialNetwork + is_lycoris = True if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris': NetworkClass = LycorisSpecialNetwork + is_lycoris = True + + if is_lycoris: + preset = PRESET['full'] + # NetworkClass.apply_preset(preset) self.network = NetworkClass( text_encoder=text_encoder, diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 6fd2cf92..c465f42f 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -3,15 +3,13 @@ import math import os import re import sys -from collections import OrderedDict from typing import List, Optional, Dict, Type, Union import torch from transformers import CLIPTextModel from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin -from .paths import SD_SCRIPTS_ROOT, KEYMAPS_ROOT -from .train_tools import get_torch_dtype +from .paths import SD_SCRIPTS_ROOT sys.path.append(SD_SCRIPTS_ROOT) @@ -22,6 +20,17 @@ from torch.utils.checkpoint import checkpoint RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") +# diffusers specific stuff +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear' + # 'GroupNorm', +] +CONV_MODULES = [ + 'Conv2d', + 'LoRACompatibleConv' +] + class LoRAModule(ToolkitModuleMixin, torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. 
@@ -197,8 +206,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_linear = child_module.__class__.__name__ in LINEAR_MODULES + is_conv2d = child_module.__class__.__name__ in CONV_MODULES is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) if is_linear or is_conv2d: diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py index b16158ec..c71de612 100644 --- a/toolkit/lycoris_special.py +++ b/toolkit/lycoris_special.py @@ -1,38 +1,147 @@ +import math import os from typing import Optional, Union, List, Type +import torch from lycoris.kohya import LycorisNetwork, LoConModule +from lycoris.modules.glora import GLoRAModule from torch import nn from transformers import CLIPTextModel - +from torch.nn import functional as F from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin +# diffusers specific stuff +LINEAR_MODULES = [ + 'Linear', + 'LoRACompatibleLinear' + # 'GroupNorm', +] +CONV_MODULES = [ + 'Conv2d', + 'LoRACompatibleConv' +] class LoConSpecialModule(ToolkitModuleMixin, LoConModule): def __init__( self, - lora_name, - org_module: nn.Module, + lora_name, org_module: nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=0., rank_dropout=0., module_dropout=0., use_cp=False, **kwargs, ): + """ if alpha == 0 or None, alpha is rank (no scaling). 
""" + # call super of super().__init__( - lora_name, - org_module, - multiplier=multiplier, - lora_dim=lora_dim, alpha=alpha, - dropout=dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, - use_cp=use_cp, - **kwargs, + call_super_init=False, + **kwargs ) + # call super of super + super(LoConModule, self).__init__() + + self.lora_name = lora_name + self.lora_dim = lora_dim + self.cp = False + + self.scalar = nn.Parameter(torch.tensor(0.0)) + orig_module_name = org_module.__class__.__name__ + if orig_module_name in CONV_MODULES: + self.isconv = True + # For general LoCon + in_dim = org_module.in_channels + k_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + out_dim = org_module.out_channels + self.down_op = F.conv2d + self.up_op = F.conv2d + if use_cp and k_size != (1, 1): + self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False) + self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False) + self.cp = True + else: + self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False) + self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False) + elif orig_module_name in LINEAR_MODULES: + self.isconv = False + self.down_op = F.linear + self.up_op = F.linear + if orig_module_name == 'GroupNorm': + # RuntimeError: mat1 and mat2 shapes cannot be multiplied (56320x120 and 320x32) + in_dim = org_module.num_channels + out_dim = org_module.num_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) + self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) + else: + raise NotImplementedError + self.shape = org_module.weight.shape + + if dropout: + self.dropout = nn.Dropout(dropout) + else: + self.dropout = nn.Identity() + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without 
casting, bf16 causes error + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.lora_up.weight) + if self.cp: + torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5)) + + self.multiplier = multiplier + self.org_module = [org_module] + self.register_load_state_dict_post_hook(self.load_weight_hook) + + def load_weight_hook(self, *args, **kwargs): + self.scalar = nn.Parameter(torch.ones_like(self.scalar)) class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): + UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "ResnetBlock2D", + "Downsample2D", + "Upsample2D", + # 'UNet2DConditionModel', + # 'Conv2d', + # 'Timesteps', + # 'TimestepEmbedding', + # 'Linear', + # 'SiLU', + # 'ModuleList', + # 'DownBlock2D', + 'ResnetBlock2D', # need + # 'GroupNorm', + # 'LoRACompatibleConv', + # 'LoRACompatibleLinear', + # 'Dropout', + # 'CrossAttnDownBlock2D', # needed + 'Transformer2DModel', # maybe not, has duplicates + # 'BasicTransformerBlock', # duplicates + # 'LayerNorm', + # 'Attention', + # 'FeedForward', + # 'GEGLU', + # 'UpBlock2D', + # 'UNetMidBlock2DCrossAttn' + ] + UNET_TARGET_REPLACE_NAME = [ + "conv_in", + "conv_out", + "time_embedding.linear_1", + "time_embedding.linear_2", + ] def __init__( self, text_encoder: Union[List[CLIPTextModel], CLIPTextModel], @@ -49,6 +158,13 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): network_module: Type[object] = LoConSpecialModule, **kwargs, ) -> None: + # call ToolkitNetworkMixin super + super().__init__( + **kwargs + ) + # call the parent of the parent LycorisNetwork + super(LycorisNetwork, self).__init__() + # LyCORIS unique stuff if dropout is None: dropout = 0 @@ -57,19 +173,162 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, 
LycorisNetwork): if module_dropout is None: module_dropout = 0 - super().__init__( - text_encoder, - unet, - multiplier=multiplier, - lora_dim=lora_dim, - conv_lora_dim=conv_lora_dim, - alpha=alpha, - conv_alpha=conv_alpha, - use_cp=use_cp, - dropout=dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, - network_module=network_module, - **kwargs, - ) + self.multiplier = multiplier + self.lora_dim = lora_dim + if not self.ENABLE_CONV: + conv_lora_dim = 0 + + self.conv_lora_dim = int(conv_lora_dim) + if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim: + print('Apply different lora dim for conv layer') + print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}') + elif self.conv_lora_dim == 0: + print('Disable conv layer') + + self.alpha = alpha + self.conv_alpha = float(conv_alpha) + if self.conv_lora_dim and self.alpha != self.conv_alpha: + print('Apply different alpha value for conv layer') + print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}') + + if 1 >= dropout >= 0: + print(f'Use Dropout value: {dropout}') + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + # create module instances + def create_modules( + prefix, + root_module: torch.nn.Module, + target_replace_modules, + target_replace_names=[] + ) -> List[network_module]: + print('Create LyCORIS Module') + loras = [] + # remove this + named_modules = root_module.named_modules() + modules = root_module.modules() + # add a few to tthe generator + + for name, module in named_modules: + module_name = module.__class__.__name__ + if module_name in target_replace_modules: + if module_name in self.MODULE_ALGO_MAP: + algo = self.MODULE_ALGO_MAP[module_name] + else: + algo = network_module + for child_name, child_module in module.named_modules(): + lora_name = prefix + '.' + name + '.' 
+ child_name + lora_name = lora_name.replace('.', '_') + if lora_name.startswith('lora_unet_input_blocks_1_0_emb_layers_1'): + print(f"{lora_name}") + + if child_module.__class__.__name__ in LINEAR_MODULES and lora_dim > 0: + lora = algo( + lora_name, child_module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + **kwargs + ) + elif child_module.__class__.__name__ in CONV_MODULES: + k_size, *_ = child_module.kernel_size + if k_size == 1 and lora_dim > 0: + lora = algo( + lora_name, child_module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + **kwargs + ) + elif conv_lora_dim > 0: + lora = algo( + lora_name, child_module, self.multiplier, + self.conv_lora_dim, self.conv_alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + **kwargs + ) + else: + continue + else: + continue + loras.append(lora) + elif name in target_replace_names: + if name in self.NAME_ALGO_MAP: + algo = self.NAME_ALGO_MAP[name] + else: + algo = network_module + lora_name = prefix + '.' 
+ name + lora_name = lora_name.replace('.', '_') + if module.__class__.__name__ == 'Linear' and lora_dim > 0: + lora = algo( + lora_name, module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + **kwargs + ) + elif module.__class__.__name__ == 'Conv2d': + k_size, *_ = module.kernel_size + if k_size == 1 and lora_dim > 0: + lora = algo( + lora_name, module, self.multiplier, + self.lora_dim, self.alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + **kwargs + ) + elif conv_lora_dim > 0: + lora = algo( + lora_name, module, self.multiplier, + self.conv_lora_dim, self.conv_alpha, + self.dropout, self.rank_dropout, self.module_dropout, + use_cp, + **kwargs + ) + else: + continue + else: + continue + loras.append(lora) + return loras + + if network_module == GLoRAModule: + print('GLoRA enabled, only train transformer') + # only train transformer (for GLoRA) + LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE = [ + "Transformer2DModel", + "Attention", + ] + LycorisSpecialNetwork.UNET_TARGET_REPLACE_NAME = [] + + if isinstance(text_encoder, list): + text_encoders = text_encoder + use_index = True + else: + text_encoders = [text_encoder] + use_index = False + + self.text_encoder_loras = [] + for i, te in enumerate(text_encoders): + self.text_encoder_loras.extend(create_modules( + LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''), + te, + LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE + )) + print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.") + + self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet, + LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE) + print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.") + + self.weights_sd = None + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: 
{lora.lora_name}" + names.add(lora.lora_name) diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index f3979cb2..beb6596a 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -21,9 +21,11 @@ class ToolkitModuleMixin: def __init__( self: Module, *args, + call_super_init: bool = True, **kwargs ): - super().__init__(*args, **kwargs) + if call_super_init: + super().__init__(*args, **kwargs) self.is_checkpointing = False self.is_normalizing = False self.normalize_scaler = 1.0 @@ -74,7 +76,10 @@ class ToolkitModuleMixin: if hasattr(self, 'lora_mid') and hasattr(self, 'cp') and self.cp: lx = self.lora_mid(self.lora_down(x)) else: - lx = self.lora_down(x) + try: + lx = self.lora_down(x) + except RuntimeError as e: + raise RuntimeError(f"Error in {self.__class__.__name__} lora_down") from e if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity): lx = self.dropout(lx) @@ -202,7 +207,7 @@ class ToolkitNetworkMixin: self._is_normalizing: bool = False self.is_sdxl = is_sdxl self.is_v2 = is_v2 - super().__init__(*args, **kwargs) + # super().__init__(*args, **kwargs) def get_keymap(self: Network): if self.is_sdxl: @@ -219,7 +224,7 @@ class ToolkitNetworkMixin: # check if file exists if os.path.exists(keymap_path): with open(keymap_path, 'r') as f: - keymap = json.load(f) + keymap = json.load(f)['ldm_diffusers_keymap'] return keymap diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 878fca19..de2259bd 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -367,9 +367,9 @@ class StableDiffusion: # was trained on 0.7 (I believe) grs = gen_config.guidance_rescale - # if grs is None or grs < 0.00001: - # grs = 0.7 - grs = 0.0 + if grs is None or grs < 0.00001: + grs = 0.7 + # grs = 0.0 extra = {} if sampler.startswith("sample_"):