diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index deb6cecf..4b744adf 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -335,6 +335,7 @@ class TrainConfig:
         ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
         if ema_config is not None:
             ema_config['use_ema'] = True
+            print(f"Using EMA")
         else:
             ema_config = {'use_ema': False}
 
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 0cf77c3b..9449db47 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -1,3 +1,4 @@
+import copy
 import json
 import math
 import os
@@ -5,6 +6,7 @@
 import re
 import sys
 from typing import List, Optional, Dict, Type, Union
 import torch
+from diffusers import UNet2DConditionModel, PixArtTransformer2DModel
 from transformers import CLIPTextModel
 from .config_modules import NetworkConfig
@@ -163,6 +165,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
             target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
             network_type: str = "lora",
+            full_train_in_out: bool = False,
             **kwargs
     ) -> None:
         """
@@ -212,6 +215,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             self.module_class = DoRAModule
             module_class = DoRAModule
 
+        self.full_train_in_out = full_train_in_out
+
         if modules_dim is not None:
             print(f"create LoRA network from weights")
         elif block_dims is not None:
@@ -389,3 +394,46 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         for lora in self.text_encoder_loras + self.unet_loras:
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
+
+        if self.full_train_in_out:
+            print("full train in out")
+            # we are going to retrain the main in out layers for VAE change usually
+            if self.is_pixart:
+                transformer: PixArtTransformer2DModel = unet
+                self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
+                self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
+
+                transformer.pos_embed.orig_forward = transformer.pos_embed.forward
+                transformer.proj_out.orig_forward = transformer.proj_out.forward
+
+                transformer.pos_embed.forward = self.transformer_pos_embed.forward
+                transformer.proj_out.forward = self.transformer_proj_out.forward
+
+            else:
+                unet: UNet2DConditionModel = unet
+                unet_conv_in: torch.nn.Conv2d = unet.conv_in
+                unet_conv_out: torch.nn.Conv2d = unet.conv_out
+
+                # clone these and replace their forwards with ours
+                self.unet_conv_in = copy.deepcopy(unet_conv_in)
+                self.unet_conv_out = copy.deepcopy(unet_conv_out)
+                unet.conv_in.orig_forward = unet_conv_in.forward
+                unet_conv_out.orig_forward = unet_conv_out.forward
+                unet.conv_in.forward = self.unet_conv_in.forward
+                unet.conv_out.forward = self.unet_conv_out.forward
+
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        # call Lora prepare_optimizer_params
+        all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
+
+        if self.full_train_in_out:
+            if self.is_pixart:
+                all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
+                all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
+            else:
+                all_params.append({"lr": unet_lr, "params": list(self.unet_conv_in.parameters())})
+                all_params.append({"lr": unet_lr, "params": list(self.unet_conv_out.parameters())})
+
+        return all_params
+
+
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 1fc86fb4..8f996f74 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -477,6 +477,7 @@ class ToolkitNetworkMixin:
             v = v.detach().clone().to("cpu").to(dtype)
             save_key = save_keymap[key] if key in save_keymap else key
             save_dict[save_key] = v
+            del state_dict[key]
 
         if extra_state_dict is not None:
             # add extra items to state dict
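
For reference, the full_train_in_out change in toolkit/lora_special.py follows one pattern: deep-copy the model's input/output layers, point the frozen originals' forward at the trainable copies, and hand the copies' parameters to the optimizer alongside the LoRA weights. The sketch below only illustrates that pattern in isolation; the Conv2d shape and learning rate are placeholders, not values from the toolkit.

    import copy
    import torch

    # Stand-in for unet.conv_in / unet.conv_out (or pos_embed / proj_out on PixArt).
    conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
    conv_in.requires_grad_(False)  # base model layer stays frozen

    # Deep-copy the layer, keep a handle to the original forward,
    # then route calls on the original module through the trainable copy.
    trainable_conv_in = copy.deepcopy(conv_in)
    trainable_conv_in.requires_grad_(True)
    conv_in.orig_forward = conv_in.forward
    conv_in.forward = trainable_conv_in.forward

    # Only the copy's parameters are handed to the optimizer.
    optimizer = torch.optim.AdamW(
        [{"lr": 1e-4, "params": list(trainable_conv_in.parameters())}]
    )

    x = torch.randn(1, 4, 64, 64)
    out = conv_in(x)       # executes the trainable copy's forward
    out.mean().backward()
    optimizer.step()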