Add ability to include conv_in and conv_out in full training when doing a LoRA

Jaret Burkett
2024-06-29 14:54:50 -06:00
parent 603ceca3ca
commit 3072d20f17
3 changed files with 50 additions and 0 deletions

View File

@@ -335,6 +335,7 @@ class TrainConfig:
         ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
         if ema_config is not None:
             ema_config['use_ema'] = True
+            print(f"Using EMA")
         else:
             ema_config = {'use_ema': False}
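
For context, the branch above just normalizes the optional EMA block: any supplied dict is forced to use_ema=True, and omitting it falls back to a disabled default. A minimal standalone sketch of that behavior; the ema_decay key is illustrative, not necessarily one the toolkit defines:

    # Sketch of the ema_config normalization above; ema_decay is illustrative.
    def normalize_ema_config(kwargs: dict) -> dict:
        ema_config = kwargs.get('ema_config', None)
        if ema_config is not None:
            ema_config['use_ema'] = True  # presence of the block implies EMA is on
        else:
            ema_config = {'use_ema': False}
        return ema_config

    assert normalize_ema_config({}) == {'use_ema': False}
    assert normalize_ema_config({'ema_config': {'ema_decay': 0.99}}) == {'ema_decay': 0.99, 'use_ema': True}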

View File

@@ -1,3 +1,4 @@
+import copy
 import json
 import math
 import os
@@ -5,6 +6,7 @@ import re
 import sys
 from typing import List, Optional, Dict, Type, Union
 import torch
+from diffusers import UNet2DConditionModel, PixArtTransformer2DModel
 from transformers import CLIPTextModel
 from .config_modules import NetworkConfig
@@ -163,6 +165,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
             target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
             network_type: str = "lora",
+            full_train_in_out: bool = False,
             **kwargs
     ) -> None:
         """
@@ -212,6 +215,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             self.module_class = DoRAModule
             module_class = DoRAModule
+
+        self.full_train_in_out = full_train_in_out
         if modules_dim is not None:
             print(f"create LoRA network from weights")
         elif block_dims is not None:
@@ -389,3 +394,46 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         for lora in self.text_encoder_loras + self.unet_loras:
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
+
+        if self.full_train_in_out:
+            print("full train in out")
+            # we are going to retrain the main in/out layers, usually for a VAE change
+            if self.is_pixart:
+                transformer: PixArtTransformer2DModel = unet
+                self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
+                self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
+                transformer.pos_embed.orig_forward = transformer.pos_embed.forward
+                transformer.proj_out.orig_forward = transformer.proj_out.forward
+                transformer.pos_embed.forward = self.transformer_pos_embed.forward
+                transformer.proj_out.forward = self.transformer_proj_out.forward
+            else:
+                unet: UNet2DConditionModel = unet
+                unet_conv_in: torch.nn.Conv2d = unet.conv_in
+                unet_conv_out: torch.nn.Conv2d = unet.conv_out
+                # clone these and replace their forwards with ours
+                self.unet_conv_in = copy.deepcopy(unet_conv_in)
+                self.unet_conv_out = copy.deepcopy(unet_conv_out)
+                unet_conv_in.orig_forward = unet_conv_in.forward
+                unet_conv_out.orig_forward = unet_conv_out.forward
+                unet_conv_in.forward = self.unet_conv_in.forward
+                unet_conv_out.forward = self.unet_conv_out.forward
+
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        # call LoRA prepare_optimizer_params for the lora modules first
+        all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
+
+        if self.full_train_in_out:
+            if self.is_pixart:
+                all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
+                all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
+            else:
+                all_params.append({"lr": unet_lr, "params": list(self.unet_conv_in.parameters())})
+                all_params.append({"lr": unet_lr, "params": list(self.unet_conv_out.parameters())})
+        return all_params
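
Two things happen in this hunk: a forward swap (a deep copy of each in/out layer is trained while the original module object stays wired into the model) and, in prepare_optimizer_params, extra optimizer param groups at the UNet learning rate. A self-contained sketch of both patterns; conv, trainable_conv, and unet_lr are illustrative names, not toolkit identifiers:

    import copy

    import torch

    # 1) Forward swap: route calls through a trainable deep copy while the
    #    original module object stays in place inside the parent model.
    conv = torch.nn.Conv2d(4, 8, kernel_size=3, padding=1)
    trainable_conv = copy.deepcopy(conv)
    conv.orig_forward = conv.forward       # keep the original forward recoverable
    conv.forward = trainable_conv.forward  # calls to conv now execute the copy

    out = conv(torch.randn(1, 4, 16, 16))
    out.sum().backward()
    assert trainable_conv.weight.grad is not None  # the copy accumulates gradients
    assert conv.weight.grad is None                # the original stays frozen

    # 2) Extra param groups: the copies join whatever the LoRA network already
    #    returns, at the UNet learning rate.
    unet_lr = 1e-4
    all_params = []  # stand-in for super().prepare_optimizer_params(...)
    all_params.append({"lr": unet_lr, "params": list(trainable_conv.parameters())})
    optimizer = torch.optim.AdamW(all_params)

Registering the copies on the network object (self.unet_conv_in, self.transformer_pos_embed, and so on) is what lets prepare_optimizer_params find them again when the optimizer is built.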

View File

@@ -477,6 +477,7 @@ class ToolkitNetworkMixin:
             v = v.detach().clone().to("cpu").to(dtype)
             save_key = save_keymap[key] if key in save_keymap else key
             save_dict[save_key] = v
+            del state_dict[key]
 
         if extra_state_dict is not None:
             # add extra items to state dict
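
The added del drops each source tensor as soon as its CPU copy is stored, so peak memory stays roughly flat while the save dict is built. A standalone sketch of the pattern; the dict contents and rename map are illustrative:

    import torch

    state_dict = {"a.weight": torch.randn(2, 2), "b.weight": torch.randn(2, 2)}
    save_keymap = {"a.weight": "lora_a.weight"}  # illustrative rename map
    dtype = torch.float16

    save_dict = {}
    for key in list(state_dict.keys()):  # snapshot keys so deleting while iterating is safe
        v = state_dict[key].detach().clone().to("cpu").to(dtype)
        save_key = save_keymap[key] if key in save_keymap else key
        save_dict[save_key] = v
        del state_dict[key]  # free the source reference immediately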