Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-03-13 14:39:50 +00:00)
Add ability to include conv_in and conv_out to full train when doing a lora
@@ -335,6 +335,7 @@ class TrainConfig:
        ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
        if ema_config is not None:
            ema_config['use_ema'] = True
            print(f"Using EMA")
        else:
            ema_config = {'use_ema': False}
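
The branch above forces use_ema on whenever a config dict is supplied, and falls back to a disabled stub otherwise. A standalone sketch of that behavior, assuming a hypothetical ema_decay key (the diff does not show the dict's contents):

    # Standalone sketch of the branch above; `ema_decay` is a hypothetical key.
    kwargs = {'ema_config': {'ema_decay': 0.999}}

    ema_config = kwargs.get('ema_config', None)
    if ema_config is not None:
        ema_config['use_ema'] = True   # forced on whenever a config is supplied
        print("Using EMA")
    else:
        ema_config = {'use_ema': False}

    print(ema_config)  # {'ema_decay': 0.999, 'use_ema': True}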
@@ -1,3 +1,4 @@
import copy
import json
import math
import os
@@ -5,6 +6,7 @@ import re
import sys
from typing import List, Optional, Dict, Type, Union
import torch
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel
from transformers import CLIPTextModel

from .config_modules import NetworkConfig
@@ -163,6 +165,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
            target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
            network_type: str = "lora",
            full_train_in_out: bool = False,
            **kwargs
    ) -> None:
        """
@@ -212,6 +215,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            self.module_class = DoRAModule
            module_class = DoRAModule

        self.full_train_in_out = full_train_in_out

        if modules_dim is not None:
            print(f"create LoRA network from weights")
        elif block_dims is not None:
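
For context, a hedged sketch of how the new flag might be passed when constructing the network; every argument except network_type and full_train_in_out is a placeholder, since the diff only shows the tail of the signature:

    # Hypothetical construction; the loader and most arguments are assumptions.
    text_encoder, unet = load_models()  # placeholder loader, not a real API

    network = LoRASpecialNetwork(
        text_encoder=text_encoder,
        unet=unet,
        lora_dim=16,                 # assumed kohya-style arguments
        alpha=16.0,
        network_type="lora",
        full_train_in_out=True,      # new flag: also retrain the in/out layers
    )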
@@ -389,3 +394,46 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

        if self.full_train_in_out:
            print("full train in out")
            # we are going to retrain the main in/out layers, usually for a VAE change
            if self.is_pixart:
                transformer: PixArtTransformer2DModel = unet
                self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
                self.transformer_proj_out = copy.deepcopy(transformer.proj_out)

                transformer.pos_embed.orig_forward = transformer.pos_embed.forward
                transformer.proj_out.orig_forward = transformer.proj_out.forward

                transformer.pos_embed.forward = self.transformer_pos_embed.forward
                transformer.proj_out.forward = self.transformer_proj_out.forward

            else:
                unet: UNet2DConditionModel = unet
                unet_conv_in: torch.nn.Conv2d = unet.conv_in
                unet_conv_out: torch.nn.Conv2d = unet.conv_out

                # clone these and replace their forwards with ours
                # (unet_conv_in/out alias unet.conv_in/out, so either reference works)
                self.unet_conv_in = copy.deepcopy(unet_conv_in)
                self.unet_conv_out = copy.deepcopy(unet_conv_out)
                unet.conv_in.orig_forward = unet_conv_in.forward
                unet_conv_out.orig_forward = unet_conv_out.forward
                unet.conv_in.forward = self.unet_conv_in.forward
                unet.conv_out.forward = self.unet_conv_out.forward

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        # call LoRA's prepare_optimizer_params, then append the in/out layer groups
        all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)

        if self.full_train_in_out:
            if self.is_pixart:
                all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
            else:
                all_params.append({"lr": unet_lr, "params": list(self.unet_conv_in.parameters())})
                all_params.append({"lr": unet_lr, "params": list(self.unet_conv_out.parameters())})

        return all_params
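
The pattern above trains a copy rather than the original module: each in/out layer is deep-copied onto the network, the original's forward is stashed as orig_forward, and calls are rerouted to the copy so its parameters receive the gradients. A minimal standalone sketch of that pattern (layer shapes here are arbitrary, not taken from any model):

    import copy
    import torch

    # Stand-in for unet.conv_in; the shape is arbitrary for illustration.
    conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

    trained_copy = copy.deepcopy(conv_in)    # the copy that will actually train
    conv_in.orig_forward = conv_in.forward   # stash the original forward
    conv_in.forward = trained_copy.forward   # reroute calls through the copy

    with torch.no_grad():
        trained_copy.weight.add_(0.1)        # make the copy observably different

    x = torch.randn(1, 4, 64, 64)
    assert torch.allclose(conv_in(x), trained_copy(x))  # calls now hit the copy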
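
prepare_optimizer_params returns a list of param-group dicts, so the extra layers simply ride along at the UNet learning rate. A hedged usage sketch (the network variable and learning-rate values are placeholders):

    # Hypothetical usage; `network` is a LoRASpecialNetwork built elsewhere and
    # the learning rates are placeholder values.
    params = network.prepare_optimizer_params(
        text_encoder_lr=1e-5, unet_lr=1e-4, default_lr=1e-4
    )
    optimizer = torch.optim.AdamW(params)  # conv_in/conv_out groups included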
@@ -477,6 +477,7 @@ class ToolkitNetworkMixin:
                v = v.detach().clone().to("cpu").to(dtype)
                save_key = save_keymap[key] if key in save_keymap else key
                save_dict[save_key] = v
                del state_dict[key]

        if extra_state_dict is not None:
            # add extra items to state dict
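
The hunk cuts off at the extra_state_dict branch. Purely as illustration, a merge of that shape might look like the sketch below; the loop body is an assumption, not the commit's actual continuation:

    # Hypothetical continuation: fold extra items (e.g. the retrained
    # conv_in/conv_out weights) into the dict being written to disk.
    if extra_state_dict is not None:
        for key, v in extra_state_dict.items():
            save_dict[key] = v.detach().clone().to("cpu").to(dtype)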