import math
import os
import re
import sys
from typing import List, Optional, Dict, Type, Union

import torch
from transformers import CLIPTextModel

from .paths import SD_SCRIPTS_ROOT

sys.path.append(SD_SCRIPTS_ROOT)

from networks.lora import LoRANetwork, get_block_index
from torch.utils.checkpoint import checkpoint

RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")


class LoRAModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear,
    instead of replacing the original Linear module.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """If alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        # if limit_rank:
        #     self.lora_dim = min(lora_dim, in_dim, out_dim)
        #     if self.lora_dim != lora_dim:
        #         print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
        # else:
        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes an error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
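        # Example (illustrative): with lora_dim=8 and alpha=4, scale = 4 / 8 = 0.5;
        # leaving alpha as None or 0 makes alpha fall back to lora_dim, so scale = 1.0
        # and the LoRA output is applied without additional scaling.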
        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier: Union[float, List[float]] = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.is_checkpointing = False

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    # this allows us to set a different multiplier per item in a batch,
    # letting us run positive and negative weights in the same batch.
    # really only useful for slider training for now
    def get_multiplier(self, lora_up):
        batch_size = lora_up.size(0)
        # the batch will have all negative prompts first and all positive prompts second.
        # our multiplier list is for a prompt pair, so we need to repeat it for positive and
        # negative prompts. if the batch is larger than our multiplier list, it is likely a
        # batch size increase, so we need to interleave the multipliers
        if isinstance(self.multiplier, list):
            if len(self.multiplier) == 1:
                # single item, just return it
                return self.multiplier[0]
            elif len(self.multiplier) == batch_size:
                # not doing CFG
                multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
            else:
                # we have a list of multipliers, so we need to get the multiplier for this batch
                multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
                # should be 1 if the total batch size was 1
                num_interleaves = (batch_size // 2) // len(self.multiplier)
                multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)

            # match lora_up rank
            if len(lora_up.size()) == 2:
                multiplier_tensor = multiplier_tensor.view(-1, 1)
            elif len(lora_up.size()) == 3:
                multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
            elif len(lora_up.size()) == 4:
                multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
            return multiplier_tensor
        else:
            return self.multiplier
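    # Worked example (illustrative): with self.multiplier = [0.3, 0.7] and a CFG batch of
    # size 8 (4 negative samples followed by 4 positive samples), the list is doubled to
    # [0.3, 0.7, 0.3, 0.7], num_interleaves = (8 // 2) // 2 = 2, and repeat_interleave()
    # produces [0.3, 0.3, 0.7, 0.7, 0.3, 0.3, 0.7, 0.7], reshaped to broadcast against lora_up.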
    def _call_forward(self, x):
        # module dropout
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return 0.0  # output is added to the original forward, so 0.0 disables this LoRA module

        lx = self.lora_down(x)

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat it as if the rank were changed.
            # the scale could be computed from the mask, but rank_dropout is used here
            # for its augmentation-like effect
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        multiplier = self.get_multiplier(lx)

        return lx * multiplier * scale

    def create_custom_forward(self):
        def custom_forward(*inputs):
            return self._call_forward(*inputs)

        return custom_forward

    def forward(self, x):
        org_forwarded = self.org_forward(x)
        # TODO: this just loses the grad, not sure why. Probably why no one else is doing it either.
        # if torch.is_grad_enabled() and self.is_checkpointing and self.training:
        #     lora_output = checkpoint(
        #         self.create_custom_forward(),
        #         x,
        #     )
        # else:
        #     lora_output = self._call_forward(x)
        lora_output = self._call_forward(x)
        return org_forwarded + lora_output

    def enable_gradient_checkpointing(self):
        self.is_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.is_checkpointing = False


class LoRASpecialNetwork(LoRANetwork):
    _multiplier: float = 1.0
    is_active: bool = False

    NUM_OF_BLOCKS = 12  # number of up/down blocks in the equivalent full model

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
    LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
    LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"

    def __init__(
        self,
        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        conv_lora_dim: Optional[int] = None,
        conv_alpha: Optional[float] = None,
        block_dims: Optional[List[int]] = None,
        block_alphas: Optional[List[float]] = None,
        conv_block_dims: Optional[List[int]] = None,
        conv_block_alphas: Optional[List[float]] = None,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        module_class: Type[object] = LoRAModule,
        varbose: Optional[bool] = False,
        train_text_encoder: Optional[bool] = True,
        train_unet: Optional[bool] = True,
    ) -> None:
        """
        LoRA network: there are many arguments, but the supported patterns are:
        1. specify lora_dim and alpha
        2. specify lora_dim, alpha, conv_lora_dim and conv_alpha
        3. specify block_dims and block_alphas : not applied to Conv2d 3x3
        4. specify block_dims, block_alphas, conv_block_dims and conv_block_alphas : also applied to Conv2d 3x3
        5. specify modules_dim and modules_alpha (for inference)
        (see the construction sketch after __init__)
        """
        # skip LoRANetwork.__init__ and call its parent's (torch.nn.Module) init directly
        super(LoRANetwork, self).__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.is_checkpointing = False

        if modules_dim is not None:
            print(f"create LoRA network from weights")
        elif block_dims is not None:
            print(f"create LoRA network from block_dims")
            print(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
            print(f"block_dims: {block_dims}")
            print(f"block_alphas: {block_alphas}")
            if conv_block_dims is not None:
                print(f"conv_block_dims: {conv_block_dims}")
                print(f"conv_block_alphas: {conv_block_alphas}")
        else:
            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
            print(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
            if self.conv_lora_dim is not None:
                print(
                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
        # create module instances
        def create_modules(
                is_unet: bool,
                text_encoder_idx: Optional[int],  # None, 1, 2
                root_module: torch.nn.Module,
                target_replace_modules: List[torch.nn.Module],
        ) -> List[LoRAModule]:
            prefix = (
                self.LORA_PREFIX_UNET
                if is_unet
                else (
                    self.LORA_PREFIX_TEXT_ENCODER
                    if text_encoder_idx is None
                    else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
                )
            )
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            dim = None
                            alpha = None

                            if modules_dim is not None:
                                # per-module dims/alphas were specified
                                if lora_name in modules_dim:
                                    dim = modules_dim[lora_name]
                                    alpha = modules_alpha[lora_name]
                            elif is_unet and block_dims is not None:
                                # block_dims were specified for the U-Net
                                block_idx = get_block_index(lora_name)
                                if is_linear or is_conv2d_1x1:
                                    dim = block_dims[block_idx]
                                    alpha = block_alphas[block_idx]
                                elif conv_block_dims is not None:
                                    dim = conv_block_dims[block_idx]
                                    alpha = conv_block_alphas[block_idx]
                            else:
                                # default: target all modules
                                if is_linear or is_conv2d_1x1:
                                    dim = self.lora_dim
                                    alpha = self.alpha
                                elif self.conv_lora_dim is not None:
                                    dim = self.conv_lora_dim
                                    alpha = self.conv_alpha

                            if dim is None or dim == 0:
                                # record skipped modules so they can be reported
                                if is_linear or is_conv2d_1x1 or (
                                        self.conv_lora_dim is not None or conv_block_dims is not None):
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha,
                                dropout=dropout,
                                rank_dropout=rank_dropout,
                                module_dropout=module_dropout,
                            )
                            loras.append(lora)
            return loras, skipped

        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

        # create LoRA modules for the text encoder(s)
        # TODO: creating every module on each call is wasteful; this should be revisited
        self.text_encoder_loras = []
        skipped_te = []
        if train_text_encoder:
            for i, text_encoder in enumerate(text_encoders):
                if len(text_encoders) > 1:
                    index = i + 1
                    print(f"create LoRA for Text Encoder {index}:")
                else:
                    index = None
                    print(f"create LoRA for Text Encoder:")

                text_encoder_loras, skipped = create_modules(False, index, text_encoder,
                                                             LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
                self.text_encoder_loras.extend(text_encoder_loras)
                skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
        if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        if train_unet:
            self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
        else:
            self.unet_loras = []
            skipped_un = []
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        skipped = skipped_te + skipped_un
        if varbose and len(skipped) > 0:
            print(
                f"because block_lr_weight is 0 or dim (rank) is 0, the following {len(skipped)} LoRA modules are skipped:"
            )
            for name in skipped:
                print(f"\t{name}")

        self.up_lr_weight: List[float] = None
        self.down_lr_weight: List[float] = None
        self.mid_lr_weight: float = None
        self.block_lr = False

        # assert that no LoRA module name is duplicated
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)
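    # Construction sketch (illustrative, not from the original source; `text_encoder` and `unet`
    # are assumed to be already-loaded models, `dims`/`alphas` per-module dicts from a checkpoint):
    #
    #   network = LoRASpecialNetwork(text_encoder, unet, lora_dim=16, alpha=8)    # pattern 1
    #   network = LoRASpecialNetwork(text_encoder, unet, modules_dim=dims,
    #                                modules_alpha=alphas)                        # pattern 5 (inference)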
    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    @property
    def multiplier(self) -> Union[float, List[float]]:
        return self._multiplier

    @multiplier.setter
    def multiplier(self, value: Union[float, List[float]]):
        self._multiplier = value
        self._update_lora_multiplier()

    def _update_lora_multiplier(self):
        if self.is_active:
            if hasattr(self, 'unet_loras'):
                for lora in self.unet_loras:
                    lora.multiplier = self._multiplier
            if hasattr(self, 'text_encoder_loras'):
                for lora in self.text_encoder_loras:
                    lora.multiplier = self._multiplier
        else:
            if hasattr(self, 'unet_loras'):
                for lora in self.unet_loras:
                    lora.multiplier = 0
            if hasattr(self, 'text_encoder_loras'):
                for lora in self.text_encoder_loras:
                    lora.multiplier = 0

    # called when the context manager is entered
    # ie: with network:
    def __enter__(self):
        self.is_active = True
        self._update_lora_multiplier()

    def __exit__(self, exc_type, exc_value, tb):
        self.is_active = False
        self._update_lora_multiplier()

    def force_to(self, device, dtype):
        self.to(device, dtype)
        loras = []
        if hasattr(self, 'unet_loras'):
            loras += self.unet_loras
        if hasattr(self, 'text_encoder_loras'):
            loras += self.text_encoder_loras
        for lora in loras:
            lora.to(device, dtype)

    def _update_checkpointing(self):
        if self.is_checkpointing:
            if hasattr(self, 'unet_loras'):
                for lora in self.unet_loras:
                    lora.enable_gradient_checkpointing()
            if hasattr(self, 'text_encoder_loras'):
                for lora in self.text_encoder_loras:
                    lora.enable_gradient_checkpointing()
        else:
            if hasattr(self, 'unet_loras'):
                for lora in self.unet_loras:
                    lora.disable_gradient_checkpointing()
            if hasattr(self, 'text_encoder_loras'):
                for lora in self.text_encoder_loras:
                    lora.disable_gradient_checkpointing()

    def enable_gradient_checkpointing(self):
        # not supported
        self.is_checkpointing = True
        self._update_checkpointing()

    def disable_gradient_checkpointing(self):
        # not supported
        self.is_checkpointing = False
        self._update_checkpointing()
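
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; assumes `unet`, `text_encoder`, `device` and `dtype`
# are already set up. The weights are attached via the parent LoRANetwork's apply_to();
# check the sd-scripts version in use for its exact signature):
#
#   network = LoRASpecialNetwork(text_encoder, unet, multiplier=1.0, lora_dim=16, alpha=8)
#   network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)
#   network.force_to(device, dtype)
#
#   with network:                          # __enter__ activates the LoRA modules
#       network.multiplier = [1.0, -1.0]   # per prompt-pair multipliers (e.g. slider training)
#       ...                                # run the denoising / training step here
#
#   network.save_weights("my_lora.safetensors", torch.float16, metadata={})
# ---------------------------------------------------------------------------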