diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 25ad1b6c..350f59bf 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -93,6 +93,7 @@ class SDTrainer(BaseSDTrainProcess): # back propagate loss to free ram loss.backward() + torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) flush() # apply gradients diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 0e2a94ad..16123fb6 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1,5 +1,6 @@ import copy import glob +import inspect from collections import OrderedDict import os from typing import Union @@ -10,6 +11,8 @@ from toolkit.data_loader import get_dataloader_from_datasets from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.embedding import Embedding from toolkit.lora_special import LoRASpecialNetwork +from toolkit.lycoris_special import LycorisSpecialNetwork +from toolkit.network_mixins import Network from toolkit.optimizer import get_optimizer from toolkit.paths import CONFIG_ROOT from toolkit.sampler import get_sampler @@ -74,6 +77,7 @@ class BaseSDTrainProcess(BaseTrainProcess): raw_datasets = preprocess_dataset_raw_config(raw_datasets) self.datasets = None self.datasets_reg = None + self.params = [] if raw_datasets is not None and len(raw_datasets) > 0: for raw_dataset in raw_datasets: dataset = DatasetConfig(**raw_dataset) @@ -120,7 +124,7 @@ class BaseSDTrainProcess(BaseTrainProcess): ) # to hold network if there is one - self.network = None + self.network: Union[Network, None] = None self.embedding = None def sample(self, step=None, is_first=False): @@ -424,25 +428,54 @@ class BaseSDTrainProcess(BaseTrainProcess): noise_scheduler = self.sd.noise_scheduler if self.train_config.xformers: - vae.set_use_memory_efficient_attention_xformers(True) + 
vae.enable_xformers_memory_efficient_attention() unet.enable_xformers_memory_efficient_attention() + if isinstance(text_encoder, list): + for te in text_encoder: + # if it has it + if hasattr(te, 'enable_xformers_memory_efficient_attention'): + te.enable_xformers_memory_efficient_attention() + if self.train_config.gradient_checkpointing: unet.enable_gradient_checkpointing() - # if isinstance(text_encoder, list): - # for te in text_encoder: - # te.enable_gradient_checkpointing() - # else: - # text_encoder.enable_gradient_checkpointing() + if isinstance(text_encoder, list): + for te in text_encoder: + if hasattr(te, 'enable_gradient_checkpointing'): + te.enable_gradient_checkpointing() + if hasattr(te, "gradient_checkpointing_enable"): + te.gradient_checkpointing_enable() + else: + if hasattr(text_encoder, 'enable_gradient_checkpointing'): + text_encoder.enable_gradient_checkpointing() + if hasattr(text_encoder, "gradient_checkpointing_enable"): + text_encoder.gradient_checkpointing_enable() + if isinstance(text_encoder, list): + for te in text_encoder: + te.requires_grad_(False) + te.eval() + else: + text_encoder.requires_grad_(False) + text_encoder.eval() unet.to(self.device_torch, dtype=dtype) unet.requires_grad_(False) unet.eval() vae = vae.to(torch.device('cpu'), dtype=dtype) vae.requires_grad_(False) vae.eval() + flush() if self.network_config is not None: - self.network = LoRASpecialNetwork( + # TODO should we completely switch to LycorisSpecialNetwork? 
+ + # default to LoCON if there are any conv layers or if it is named + NetworkClass = LoRASpecialNetwork + if self.network_config.conv is not None and self.network_config.conv > 0: + NetworkClass = LycorisSpecialNetwork + if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris': + NetworkClass = LycorisSpecialNetwork + + self.network = NetworkClass( text_encoder=text_encoder, unet=unet, lora_dim=self.network_config.linear, @@ -468,14 +501,21 @@ class BaseSDTrainProcess(BaseTrainProcess): ) self.network.prepare_grad_etc(text_encoder, unet) + flush() params = self.get_params() if not params: + # LyCORIS doesnt have default_lr + config = { + 'text_encoder_lr': self.train_config.lr, + 'unet_lr': self.train_config.lr, + } + sig = inspect.signature(self.network.prepare_optimizer_params) + if 'default_lr' in sig.parameters: + config['default_lr'] = self.train_config.lr params = self.network.prepare_optimizer_params( - text_encoder_lr=self.train_config.lr, - unet_lr=self.train_config.lr, - default_lr=self.train_config.lr + **config ) if self.train_config.gradient_checkpointing: @@ -490,6 +530,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.print(f"Loading from {latest_save_path}") self.load_weights(latest_save_path) self.network.multiplier = 1.0 + + flush() elif self.embed_config is not None: self.embedding = Embedding( sd=self.sd, @@ -508,7 +550,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if not params: # set trainable params params = self.embedding.get_trainable_params() - + flush() else: # set them to train or not if self.train_config.train_unet: @@ -546,9 +588,16 @@ class BaseSDTrainProcess(BaseTrainProcess): unet_lr=self.train_config.lr, default_lr=self.train_config.lr ) - + flush() ### HOOK ### params = self.hook_add_extra_train_params(params) + self.params = [] + + for param in params: + if isinstance(param, dict): + self.params += param['params'] + else: + self.params.append(param) optimizer_type = 
self.train_config.optimizer.lower() optimizer = get_optimizer(params, optimizer_type, learning_rate=self.train_config.lr, @@ -568,6 +617,7 @@ class BaseSDTrainProcess(BaseTrainProcess): ) self.lr_scheduler = lr_scheduler + flush() ### HOOK ### self.hook_before_train_loop() @@ -639,7 +689,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # turn on normalization if we are using it and it is not on if self.network is not None and self.network_config.normalize and not self.network.is_normalizing: self.network.is_normalizing = True - + flush() ### HOOK ### loss_dict = self.hook_train_loop(batch) flush() diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 3c44d31a..f3ea686b 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -37,9 +37,11 @@ class SampleConfig: self.ext: ImgExt = kwargs.get('format', 'jpg') +NetworkType = Literal['lora', 'locon'] + class NetworkConfig: def __init__(self, **kwargs): - self.type: str = kwargs.get('type', 'lora') + self.type: NetworkType = kwargs.get('type', 'lora') rank = kwargs.get('rank', None) linear = kwargs.get('linear', None) if rank is not None: @@ -86,6 +88,7 @@ class TrainConfig: self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True) self.weight_jitter = kwargs.get('weight_jitter', 0.0) self.merge_network_on_save = kwargs.get('merge_network_on_save', False) + self.max_grad_norm = kwargs.get('max_grad_norm', 1.0) class ModelConfig: diff --git a/toolkit/lora.py b/toolkit/lora.py deleted file mode 100644 index 0780cbc3..00000000 --- a/toolkit/lora.py +++ /dev/null @@ -1,243 +0,0 @@ -# ref: -# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py -# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py -# - https://github.com/p1atdev/LECO/blob/main/lora.py - -import os -import math -from typing import Optional, List, Type, Set, Literal -from collections import OrderedDict - -import torch -import torch.nn as nn -from diffusers import 
UNet2DConditionModel -from safetensors.torch import save_file - -from toolkit.metadata import add_model_hash_to_meta - -UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [ - "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2 -] -UNET_TARGET_REPLACE_MODULE_CONV = [ - "ResnetBlock2D", - "Downsample2D", - "Upsample2D", -] # locon, 3clier - -LORA_PREFIX_UNET = "lora_unet" - -DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER - -TRAINING_METHODS = Literal[ - "noxattn", # train all layers except x-attns and time_embed layers - "innoxattn", # train all layers except self attention layers - "selfattn", # ESD-u, train only self attention layers - "xattn", # ESD-x, train only x attention layers - "full", # train all layers - # "notime", - # "xlayer", - # "outxattn", - # "outsattn", - # "inxattn", - # "inmidsattn", - # "selflayer", -] - - -class LoRAModule(nn.Module): - """ - replaces forward method of the original Linear, instead of replacing the original Linear module. - """ - - def __init__( - self, - lora_name, - org_module: nn.Module, - multiplier=1.0, - lora_dim=4, - alpha=1, - ): - """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__() - self.lora_name = lora_name - self.lora_dim = lora_dim - - if org_module.__class__.__name__ == "Linear": - in_dim = org_module.in_features - out_dim = org_module.out_features - self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) - self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) - - elif org_module.__class__.__name__ == "Conv2d": # 一応 - in_dim = org_module.in_channels - out_dim = org_module.out_channels - - self.lora_dim = min(self.lora_dim, in_dim, out_dim) - if self.lora_dim != lora_dim: - print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}") - - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = nn.Conv2d( - in_dim, self.lora_dim, kernel_size, stride, padding, bias=False - ) - self.lora_up = nn.Conv2d(self.lora_dim, 
out_dim, (1, 1), (1, 1), bias=False) - - if type(alpha) == torch.Tensor: - alpha = alpha.detach().numpy() - alpha = lora_dim if alpha is None or alpha == 0 else alpha - self.scale = alpha / self.lora_dim - self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える - - # same as microsoft's - nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - nn.init.zeros_(self.lora_up.weight) - - self.multiplier = multiplier - self.org_module = org_module # remove in applying - - def apply_to(self): - self.org_forward = self.org_module.forward - self.org_module.forward = self.forward - del self.org_module - - def forward(self, x): - return ( - self.org_forward(x) - + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale - ) - - -class LoRANetwork(nn.Module): - def __init__( - self, - unet: UNet2DConditionModel, - rank: int = 4, - multiplier: float = 1.0, - alpha: float = 1.0, - train_method: TRAINING_METHODS = "full", - ) -> None: - super().__init__() - - self.multiplier = multiplier - self.lora_dim = rank - self.alpha = alpha - - # LoRAのみ - self.module = LoRAModule - - # unetのloraを作る - self.unet_loras = self.create_modules( - LORA_PREFIX_UNET, - unet, - DEFAULT_TARGET_REPLACE, - self.lora_dim, - self.multiplier, - train_method=train_method, - ) - print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") - - # assertion 名前の被りがないか確認しているようだ - lora_names = set() - for lora in self.unet_loras: - assert ( - lora.lora_name not in lora_names - ), f"duplicated lora name: {lora.lora_name}. 
{lora_names}" - lora_names.add(lora.lora_name) - - # 適用する - for lora in self.unet_loras: - lora.apply_to() - self.add_module( - lora.lora_name, - lora, - ) - - del unet - - torch.cuda.empty_cache() - - def create_modules( - self, - prefix: str, - root_module: nn.Module, - target_replace_modules: List[str], - rank: int, - multiplier: float, - train_method: TRAINING_METHODS, - ) -> list: - loras = [] - - for name, module in root_module.named_modules(): - if train_method == "noxattn": # Cross Attention と Time Embed 以外学習 - if "attn2" in name or "time_embed" in name: - continue - elif train_method == "innoxattn": # Cross Attention 以外学習 - if "attn2" in name: - continue - elif train_method == "selfattn": # Self Attention のみ学習 - if "attn1" not in name: - continue - elif train_method == "xattn": # Cross Attention のみ学習 - if "attn2" not in name: - continue - elif train_method == "full": # 全部学習 - pass - else: - raise NotImplementedError( - f"train_method: {train_method} is not implemented." - ) - if module.__class__.__name__ in target_replace_modules: - for child_name, child_module in module.named_modules(): - if child_module.__class__.__name__ in ["Linear", "Conv2d"]: - lora_name = prefix + "." + name + "." 
+ child_name - lora_name = lora_name.replace(".", "_") - print(f"{lora_name}") - lora = self.module( - lora_name, child_module, multiplier, rank, self.alpha - ) - loras.append(lora) - - return loras - - def prepare_optimizer_params(self): - all_params = [] - - if self.unet_loras: # 実質これしかない - params = [] - [params.extend(lora.parameters()) for lora in self.unet_loras] - param_data = {"params": params} - all_params.append(param_data) - - return all_params - - def save_weights(self, file, dtype=None, metadata: Optional[dict] = None): - state_dict = self.state_dict() - if metadata is None: - metadata = OrderedDict() - - if dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(dtype) - state_dict[key] = v - - for key in list(state_dict.keys()): - if not key.startswith("lora"): - # remove any not lora - del state_dict[key] - - metadata = add_model_hash_to_meta(state_dict, metadata) - if os.path.splitext(file)[1] == ".safetensors": - save_file(state_dict, file, metadata) - else: - torch.save(state_dict, file) - - def __enter__(self): - for lora in self.unet_loras: - lora.multiplier = 1.0 - - def __exit__(self, exc_type, exc_value, tb): - for lora in self.unet_loras: - lora.multiplier = 0 diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index a4012060..6fd2cf92 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -9,6 +9,7 @@ from typing import List, Optional, Dict, Type, Union import torch from transformers import CLIPTextModel +from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin from .paths import SD_SCRIPTS_ROOT, KEYMAPS_ROOT from .train_tools import get_torch_dtype @@ -21,7 +22,7 @@ from torch.utils.checkpoint import checkpoint RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") -class LoRAModule(torch.nn.Module): +class LoRAModule(ToolkitModuleMixin, torch.nn.Module): """ replaces forward method of the 
original Linear, instead of replacing the original Linear module. """ @@ -40,6 +41,7 @@ class LoRAModule(torch.nn.Module): """if alpha == 0 or None, alpha is rank (no scaling).""" super().__init__() self.lora_name = lora_name + self.scalar = torch.tensor(1.0) if org_module.__class__.__name__ == "Conv2d": in_dim = org_module.in_channels @@ -89,153 +91,8 @@ class LoRAModule(torch.nn.Module): self.org_module.forward = self.forward del self.org_module - # this allows us to set different multipliers on a per item in a batch basis - # allowing us to run positive and negative weights in the same batch - # really only useful for slider training for now - def get_multiplier(self, lora_up): - with torch.no_grad(): - batch_size = lora_up.size(0) - # batch will have all negative prompts first and positive prompts second - # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts - # if there is more than our multiplier, it is likely a batch size increase, so we need to - # interleave the multipliers - if isinstance(self.multiplier, list): - if len(self.multiplier) == 0: - # single item, just return it - return self.multiplier[0] - elif len(self.multiplier) == batch_size: - # not doing CFG - multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype) - else: - # we have a list of multipliers, so we need to get the multiplier for this batch - multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype) - # should be 1 for if total batch size was 1 - num_interleaves = (batch_size // 2) // len(self.multiplier) - multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves) - - # match lora_up rank - if len(lora_up.size()) == 2: - multiplier_tensor = multiplier_tensor.view(-1, 1) - elif len(lora_up.size()) == 3: - multiplier_tensor = multiplier_tensor.view(-1, 1, 1) - elif len(lora_up.size()) == 4: - multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1) - return 
multiplier_tensor.detach() - - else: - return self.multiplier - - def _call_forward(self, x): - # module dropout - if self.module_dropout is not None and self.training: - if torch.rand(1) < self.module_dropout: - return 0.0 # added to original forward - - lx = self.lora_down(x) - - # normal dropout - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) - - # rank dropout - if self.rank_dropout is not None and self.training: - mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout - if len(lx.size()) == 3: - mask = mask.unsqueeze(1) # for Text Encoder - elif len(lx.size()) == 4: - mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d - lx = lx * mask - - # scaling for rank dropout: treat as if the rank is changed - # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる - scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability - else: - scale = self.scale - - lx = self.lora_up(lx) - - return lx * scale - - def forward(self, x): - org_forwarded = self.org_forward(x) - lora_output = self._call_forward(x) - multiplier = self.get_multiplier(lora_output) - - if self.is_normalizing: - with torch.no_grad(): - - # do this calculation without set multiplier and instead use same polarity, but with 1.0 multiplier - if isinstance(multiplier, torch.Tensor): - norm_multiplier = multiplier.clone().detach() * 10 - norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0) - else: - norm_multiplier = multiplier - - # get a dim array from orig forward that had index of all dimensions except the batch and channel - - # Calculate the target magnitude for the combined output - orig_max = torch.max(torch.abs(org_forwarded)) - - # Calculate the additional increase in magnitude that lora_output would introduce - potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded)) - - epsilon = 1e-6 # Small constant to avoid division 
by zero - - # Calculate the scaling factor for the lora_output - # to ensure that the potential increase in magnitude doesn't change the original max - normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon) - normalize_scaler = normalize_scaler.detach() - - # save the scaler so it can be applied later - self.normalize_scaler = normalize_scaler.clone().detach() - - lora_output *= normalize_scaler - - return org_forwarded + (lora_output * multiplier) - - def enable_gradient_checkpointing(self): - self.is_checkpointing = True - - def disable_gradient_checkpointing(self): - self.is_checkpointing = False - - @torch.no_grad() - def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): - """ - Applied the previous normalization calculation to the module. - This must be called before saving or normalization will be lost. - It is probably best to call after each batch as well. - We just scale the up down weights to match this vector - :return: - """ - # get state dict - state_dict = self.state_dict() - dtype = state_dict['lora_up.weight'].dtype - device = state_dict['lora_up.weight'].device - - # todo should we do this at fp32? 
- if isinstance(self.normalize_scaler, torch.Tensor): - scaler = self.normalize_scaler.clone().detach() - else: - scaler = torch.tensor(self.normalize_scaler).to(device, dtype=dtype) - - total_module_scale = scaler / target_normalize_scaler - num_modules_layers = 2 # up and down - up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \ - .to(device, dtype=dtype) - - # apply the scaler to the up and down weights - for key in state_dict.keys(): - if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'): - # do it inplace do params are updated - state_dict[key] *= up_down_scale - - # reset the normalization scaler - self.normalize_scaler = target_normalize_scaler - - -class LoRASpecialNetwork(LoRANetwork): +class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] @@ -445,154 +302,3 @@ class LoRASpecialNetwork(LoRANetwork): for lora in self.text_encoder_loras + self.unet_loras: assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) - - def get_keymap(self): - if self.is_sdxl: - keymap_tail = 'sdxl' - elif self.is_v2: - keymap_tail = 'sd2' - else: - keymap_tail = 'sd1' - # load keymap - keymap_name = f"stable_diffusion_locon_{keymap_tail}.json" - - keymap = None - # check if file exists - if os.path.exists(keymap_name): - with open(keymap_name, 'r') as f: - keymap = json.load(f) - - return keymap - - def save_weights(self, file, dtype, metadata): - keymap = self.get_keymap() - - save_keymap = {} - if keymap is not None: - for ldm_key, diffusers_key in keymap.items(): - # invert them - save_keymap[diffusers_key] = ldm_key - - if metadata is not None and len(metadata) == 0: - metadata = None - - state_dict = self.state_dict() - save_dict = OrderedDict() - - if dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(dtype) - save_key = 
save_keymap[key] if key in save_keymap else key - save_dict[save_key] = v - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import save_file - save_file(save_dict, file, metadata) - else: - torch.save(save_dict, file) - - def load_weights(self, file): - # allows us to save and load to and from ldm weights - keymap = self.get_keymap() - keymap = {} if keymap is None else keymap - - if os.path.splitext(file)[1] == ".safetensors": - from safetensors.torch import load_file - - weights_sd = load_file(file) - else: - weights_sd = torch.load(file, map_location="cpu") - - load_sd = OrderedDict() - for key, value in weights_sd.items(): - load_key = keymap[key] if key in keymap else key - load_sd[load_key] = value - - info = self.load_state_dict(load_sd, False) - return info - - @property - def multiplier(self) -> Union[float, List[float]]: - return self._multiplier - - @multiplier.setter - def multiplier(self, value: Union[float, List[float]]): - self._multiplier = value - self._update_lora_multiplier() - - def _update_lora_multiplier(self): - - if self.is_active: - if hasattr(self, 'unet_loras'): - for lora in self.unet_loras: - lora.multiplier = self._multiplier - if hasattr(self, 'text_encoder_loras'): - for lora in self.text_encoder_loras: - lora.multiplier = self._multiplier - else: - if hasattr(self, 'unet_loras'): - for lora in self.unet_loras: - lora.multiplier = 0 - if hasattr(self, 'text_encoder_loras'): - for lora in self.text_encoder_loras: - lora.multiplier = 0 - - # called when the context manager is entered - # ie: with network: - def __enter__(self): - self.is_active = True - self._update_lora_multiplier() - - def __exit__(self, exc_type, exc_value, tb): - self.is_active = False - self._update_lora_multiplier() - - def force_to(self, device, dtype): - self.to(device, dtype) - loras = [] - if hasattr(self, 'unet_loras'): - loras += self.unet_loras - if hasattr(self, 'text_encoder_loras'): - loras += self.text_encoder_loras - for 
lora in loras: - lora.to(device, dtype) - - def get_all_modules(self): - loras = [] - if hasattr(self, 'unet_loras'): - loras += self.unet_loras - if hasattr(self, 'text_encoder_loras'): - loras += self.text_encoder_loras - return loras - - def _update_checkpointing(self): - for module in self.get_all_modules(): - if self.is_checkpointing: - module.enable_gradient_checkpointing() - else: - module.disable_gradient_checkpointing() - - def enable_gradient_checkpointing(self): - # not supported - self.is_checkpointing = True - self._update_checkpointing() - - def disable_gradient_checkpointing(self): - # not supported - self.is_checkpointing = False - self._update_checkpointing() - - @property - def is_normalizing(self) -> bool: - return self._is_normalizing - - @is_normalizing.setter - def is_normalizing(self, value: bool): - self._is_normalizing = value - for module in self.get_all_modules(): - module.is_normalizing = self._is_normalizing - - def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): - for module in self.get_all_modules(): - module.apply_stored_normalizer(target_normalize_scaler) diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py new file mode 100644 index 00000000..b16158ec --- /dev/null +++ b/toolkit/lycoris_special.py @@ -0,0 +1,75 @@ +import os +from typing import Optional, Union, List, Type + +from lycoris.kohya import LycorisNetwork, LoConModule +from torch import nn +from transformers import CLIPTextModel + +from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin + + +class LoConSpecialModule(ToolkitModuleMixin, LoConModule): + def __init__( + self, + lora_name, + org_module: nn.Module, + multiplier=1.0, + lora_dim=4, alpha=1, + dropout=0., rank_dropout=0., module_dropout=0., + use_cp=False, + **kwargs, + ): + super().__init__( + lora_name, + org_module, + multiplier=multiplier, + lora_dim=lora_dim, alpha=alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + 
use_cp=use_cp, + **kwargs, + ) + + +class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + use_cp: Optional[bool] = False, + network_module: Type[object] = LoConSpecialModule, + **kwargs, + ) -> None: + # LyCORIS unique stuff + if dropout is None: + dropout = 0 + if rank_dropout is None: + rank_dropout = 0 + if module_dropout is None: + module_dropout = 0 + + super().__init__( + text_encoder, + unet, + multiplier=multiplier, + lora_dim=lora_dim, + conv_lora_dim=conv_lora_dim, + alpha=alpha, + conv_alpha=conv_alpha, + use_cp=use_cp, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + network_module=network_module, + **kwargs, + ) + diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py new file mode 100644 index 00000000..f3979cb2 --- /dev/null +++ b/toolkit/network_mixins.py @@ -0,0 +1,358 @@ +import json +import os +from collections import OrderedDict +from typing import Optional, Union, List, Type, TYPE_CHECKING + +import torch +from torch import nn + +from toolkit.metadata import add_model_hash_to_meta +from toolkit.paths import KEYMAPS_ROOT + +if TYPE_CHECKING: + from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule + from toolkit.lora_special import LoRASpecialNetwork, LoRAModule + +Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork'] +Module = Union['LoConSpecialModule', 'LoRAModule'] + + +class ToolkitModuleMixin: + def __init__( + self: Module, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.is_checkpointing = False + self.is_normalizing = False + self.normalize_scaler = 1.0 + + # this allows us to set 
different multipliers on a per item in a batch basis + # allowing us to run positive and negative weights in the same batch + # really only useful for slider training for now + def get_multiplier(self: Module, lora_up): + with torch.no_grad(): + batch_size = lora_up.size(0) + # batch will have all negative prompts first and positive prompts second + # our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts + # if there is more than our multiplier, it is likely a batch size increase, so we need to + # interleave the multipliers + if isinstance(self.multiplier, list): + if len(self.multiplier) == 0: + # single item, just return it + return self.multiplier[0] + elif len(self.multiplier) == batch_size: + # not doing CFG + multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype) + else: + + # we have a list of multipliers, so we need to get the multiplier for this batch + multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype) + # should be 1 for if total batch size was 1 + num_interleaves = (batch_size // 2) // len(self.multiplier) + multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves) + + # match lora_up rank + if len(lora_up.size()) == 2: + multiplier_tensor = multiplier_tensor.view(-1, 1) + elif len(lora_up.size()) == 3: + multiplier_tensor = multiplier_tensor.view(-1, 1, 1) + elif len(lora_up.size()) == 4: + multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1) + return multiplier_tensor.detach() + + else: + return self.multiplier + + def _call_forward(self: Module, x): + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return 0.0 # added to original forward + + if hasattr(self, 'lora_mid') and hasattr(self, 'cp') and self.cp: + lx = self.lora_mid(self.lora_down(x)) + else: + lx = self.lora_down(x) + + if isinstance(self.dropout, nn.Dropout) or 
isinstance(self.dropout, nn.Identity): + lx = self.dropout(lx) + # normal dropout + elif self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.rank_dropout > 0 and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + # handle trainable scaler method locon does + if hasattr(self, 'scalar'): + scale *= self.scalar + + return lx * scale + + def forward(self: Module, x): + org_forwarded = self.org_forward(x) + lora_output = self._call_forward(x) + multiplier = self.get_multiplier(lora_output) + + if self.is_normalizing: + with torch.no_grad(): + + # do this calculation without set multiplier and instead use same polarity, but with 1.0 multiplier + if isinstance(multiplier, torch.Tensor): + norm_multiplier = multiplier.clone().detach() * 10 + norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0) + else: + norm_multiplier = multiplier + + # get a dim array from orig forward that had index of all dimensions except the batch and channel + + # Calculate the target magnitude for the combined output + orig_max = torch.max(torch.abs(org_forwarded)) + + # Calculate the additional increase in magnitude that lora_output would introduce + potential_max_increase = torch.max( + torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded)) + + epsilon = 1e-6 # Small constant to avoid division by zero + + # Calculate the scaling factor for the lora_output + # to ensure 
that the potential increase in magnitude doesn't change the original max + normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon) + normalize_scaler = normalize_scaler.detach() + + # save the scaler so it can be applied later + self.normalize_scaler = normalize_scaler.clone().detach() + + lora_output *= normalize_scaler + + return org_forwarded + (lora_output * multiplier) + + def enable_gradient_checkpointing(self: Module): + self.is_checkpointing = True + + def disable_gradient_checkpointing(self: Module): + self.is_checkpointing = False + + @torch.no_grad() + def apply_stored_normalizer(self: Module, target_normalize_scaler: float = 1.0): + """ + Applied the previous normalization calculation to the module. + This must be called before saving or normalization will be lost. + It is probably best to call after each batch as well. + We just scale the up down weights to match this vector + :return: + """ + # get state dict + state_dict = self.state_dict() + dtype = state_dict['lora_up.weight'].dtype + device = state_dict['lora_up.weight'].device + + # todo should we do this at fp32? 
class ToolkitNetworkMixin:
    """
    Mixin shared by the toolkit LoRA/LyCORIS network classes.

    Provides keymap-aware (ldm <-> diffusers) weight saving and loading, a
    ``multiplier`` that is pushed down to every wrapped LoRA module, context
    manager activation (``with network:``), and pass-through toggles for
    normalization and checkpointing state.
    """

    def __init__(
            self: 'Network',
            *args,
            train_text_encoder: Optional[bool] = True,
            train_unet: Optional[bool] = True,
            is_sdxl=False,
            is_v2=False,
            **kwargs
    ):
        self.train_text_encoder = train_text_encoder
        self.train_unet = train_unet
        self.is_checkpointing = False
        self._multiplier: float = 1.0
        self.is_active: bool = False
        self._is_normalizing: bool = False
        self.is_sdxl = is_sdxl
        self.is_v2 = is_v2
        super().__init__(*args, **kwargs)

    def get_keymap(self: 'Network'):
        # Pick the keymap that matches the base model family.
        if self.is_sdxl:
            keymap_tail = 'sdxl'
        elif self.is_v2:
            keymap_tail = 'sd2'
        else:
            keymap_tail = 'sd1'
        keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
        keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)

        # Missing keymap file -> None (callers treat that as "no remapping").
        if not os.path.exists(keymap_path):
            return None
        with open(keymap_path, 'r') as f:
            return json.load(f)

    def save_weights(self: 'Network', file, dtype=torch.float16, metadata=None):
        """Save the network weights (remapped to ldm key space when a keymap exists)."""
        keymap = self.get_keymap()
        # Invert the ldm->diffusers keymap so we save under ldm keys.
        save_keymap = {}
        if keymap is not None:
            for ldm_key, diffusers_key in keymap.items():
                save_keymap[diffusers_key] = ldm_key

        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()
        save_dict = OrderedDict(
            (save_keymap.get(key, key),
             state_dict[key].detach().clone().to("cpu").to(dtype))
            for key in list(state_dict.keys())
        )

        if metadata is None:
            metadata = OrderedDict()
        metadata = add_model_hash_to_meta(state_dict, metadata)
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            save_file(save_dict, file, metadata)
        else:
            torch.save(save_dict, file)

    def load_weights(self: 'Network', file):
        """Load weights saved in either ldm or diffusers key space."""
        keymap = self.get_keymap() or {}

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file
            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        load_sd = OrderedDict(
            (keymap.get(key, key), value) for key, value in weights_sd.items()
        )
        return self.load_state_dict(load_sd, False)

    @property
    def multiplier(self) -> Union[float, List[float]]:
        """Current LoRA strength; assigning also pushes it to every module."""
        return self._multiplier

    @multiplier.setter
    def multiplier(self, value: Union[float, List[float]]):
        self._multiplier = value
        self._update_lora_multiplier()

    def _update_lora_multiplier(self: 'Network'):
        # An inactive network is forced to 0 so it contributes nothing.
        effective = self._multiplier if self.is_active else 0
        for lora in self.get_all_modules():
            lora.multiplier = effective

    def __enter__(self: 'Network'):
        # `with network:` activates the LoRA contribution.
        self.is_active = True
        self._update_lora_multiplier()

    def __exit__(self: 'Network', exc_type, exc_value, tb):
        self.is_active = False
        self._update_lora_multiplier()

    def force_to(self: 'Network', device, dtype):
        # Move the network, then explicitly move each wrapped module too.
        self.to(device, dtype)
        for lora in self.get_all_modules():
            lora.to(device, dtype)

    def get_all_modules(self: 'Network'):
        # unet modules first, then text encoder modules (when present).
        modules = []
        for attr_name in ('unet_loras', 'text_encoder_loras'):
            if hasattr(self, attr_name):
                modules += getattr(self, attr_name)
        return modules

    def _update_checkpointing(self: 'Network'):
        for module in self.get_all_modules():
            if self.is_checkpointing:
                module.enable_gradient_checkpointing()
            else:
                module.disable_gradient_checkpointing()

    @property
    def is_normalizing(self: 'Network') -> bool:
        return self._is_normalizing

    @is_normalizing.setter
    def is_normalizing(self: 'Network', value: bool):
        # Propagate the flag to every wrapped module.
        self._is_normalizing = value
        for module in self.get_all_modules():
            module.is_normalizing = self._is_normalizing

    def apply_stored_normalizer(self: 'Network', target_normalize_scaler: float = 1.0):
        # Fan the call out to every wrapped module.
        for module in self.get_all_modules():
            module.apply_stored_normalizer(target_normalize_scaler)