mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Fully tested and now supporting locon on sdxl. If you have the ram
This commit is contained in:
@@ -5,6 +5,7 @@ from collections import OrderedDict
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from lycoris.config import PRESET
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from toolkit.data_loader import get_dataloader_from_datasets
|
||||
@@ -468,12 +469,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.network_config is not None:
|
||||
# TODO should we completely switch to LycorisSpecialNetwork?
|
||||
|
||||
is_lycoris = False
|
||||
# default to LoCON if there are any conv layers or if it is named
|
||||
NetworkClass = LoRASpecialNetwork
|
||||
if self.network_config.conv is not None and self.network_config.conv > 0:
|
||||
NetworkClass = LycorisSpecialNetwork
|
||||
is_lycoris = True
|
||||
if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
|
||||
NetworkClass = LycorisSpecialNetwork
|
||||
is_lycoris = True
|
||||
|
||||
if is_lycoris:
|
||||
preset = PRESET['full']
|
||||
# NetworkClass.apply_preset(preset)
|
||||
|
||||
self.network = NetworkClass(
|
||||
text_encoder=text_encoder,
|
||||
|
||||
@@ -3,15 +3,13 @@ import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from typing import List, Optional, Dict, Type, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
|
||||
from .paths import SD_SCRIPTS_ROOT, KEYMAPS_ROOT
|
||||
from .train_tools import get_torch_dtype
|
||||
from .paths import SD_SCRIPTS_ROOT
|
||||
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
@@ -22,6 +20,17 @@ from torch.utils.checkpoint import checkpoint
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
|
||||
# diffusers specific stuff
|
||||
LINEAR_MODULES = [
|
||||
'Linear',
|
||||
'LoRACompatibleLinear'
|
||||
# 'GroupNorm',
|
||||
]
|
||||
CONV_MODULES = [
|
||||
'Conv2d',
|
||||
'LoRACompatibleConv'
|
||||
]
|
||||
|
||||
class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
@@ -197,8 +206,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_linear = child_module.__class__.__name__.in_(LINEAR_MODULES)
|
||||
is_conv2d = child_module.__class__.__name__.in_(CONV_MODULES)
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
|
||||
@@ -1,38 +1,147 @@
|
||||
import math
|
||||
import os
|
||||
from typing import Optional, Union, List, Type
|
||||
|
||||
import torch
|
||||
from lycoris.kohya import LycorisNetwork, LoConModule
|
||||
from lycoris.modules.glora import GLoRAModule
|
||||
from torch import nn
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from torch.nn import functional as F
|
||||
from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
|
||||
|
||||
# diffusers specific stuff
|
||||
LINEAR_MODULES = [
|
||||
'Linear',
|
||||
'LoRACompatibleLinear'
|
||||
# 'GroupNorm',
|
||||
]
|
||||
CONV_MODULES = [
|
||||
'Conv2d',
|
||||
'LoRACompatibleConv'
|
||||
]
|
||||
|
||||
class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: nn.Module,
|
||||
lora_name, org_module: nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4, alpha=1,
|
||||
dropout=0., rank_dropout=0., module_dropout=0.,
|
||||
use_cp=False,
|
||||
**kwargs,
|
||||
):
|
||||
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
||||
# call super of
|
||||
super().__init__(
|
||||
lora_name,
|
||||
org_module,
|
||||
multiplier=multiplier,
|
||||
lora_dim=lora_dim, alpha=alpha,
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
use_cp=use_cp,
|
||||
**kwargs,
|
||||
call_super_init=False,
|
||||
**kwargs
|
||||
)
|
||||
# call super of super
|
||||
super(LoConModule, self).__init__()
|
||||
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
self.cp = False
|
||||
|
||||
self.scalar = nn.Parameter(torch.tensor(0.0))
|
||||
orig_module_name = org_module.__class__.__name__
|
||||
if orig_module_name in CONV_MODULES:
|
||||
self.isconv = True
|
||||
# For general LoCon
|
||||
in_dim = org_module.in_channels
|
||||
k_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
out_dim = org_module.out_channels
|
||||
self.down_op = F.conv2d
|
||||
self.up_op = F.conv2d
|
||||
if use_cp and k_size != (1, 1):
|
||||
self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
|
||||
self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False)
|
||||
self.cp = True
|
||||
else:
|
||||
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
|
||||
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
||||
elif orig_module_name in LINEAR_MODULES:
|
||||
self.isconv = False
|
||||
self.down_op = F.linear
|
||||
self.up_op = F.linear
|
||||
if orig_module_name == 'GroupNorm':
|
||||
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (56320x120 and 320x32)
|
||||
in_dim = org_module.num_channels
|
||||
out_dim = org_module.num_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
|
||||
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.shape = org_module.weight.shape
|
||||
|
||||
if dropout:
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
else:
|
||||
self.dropout = nn.Identity()
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
# same as microsoft's
|
||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||
torch.nn.init.kaiming_uniform_(self.lora_up.weight)
|
||||
if self.cp:
|
||||
torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = [org_module]
|
||||
self.register_load_state_dict_post_hook(self.load_weight_hook)
|
||||
|
||||
def load_weight_hook(self, *args, **kwargs):
|
||||
self.scalar = nn.Parameter(torch.ones_like(self.scalar))
|
||||
|
||||
|
||||
class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
"ResnetBlock2D",
|
||||
"Downsample2D",
|
||||
"Upsample2D",
|
||||
# 'UNet2DConditionModel',
|
||||
# 'Conv2d',
|
||||
# 'Timesteps',
|
||||
# 'TimestepEmbedding',
|
||||
# 'Linear',
|
||||
# 'SiLU',
|
||||
# 'ModuleList',
|
||||
# 'DownBlock2D',
|
||||
'ResnetBlock2D', # need
|
||||
# 'GroupNorm',
|
||||
# 'LoRACompatibleConv',
|
||||
# 'LoRACompatibleLinear',
|
||||
# 'Dropout',
|
||||
# 'CrossAttnDownBlock2D', # needed
|
||||
'Transformer2DModel', # maybe not, has duplicates
|
||||
# 'BasicTransformerBlock', # duplicates
|
||||
# 'LayerNorm',
|
||||
# 'Attention',
|
||||
# 'FeedForward',
|
||||
# 'GEGLU',
|
||||
# 'UpBlock2D',
|
||||
# 'UNetMidBlock2DCrossAttn'
|
||||
]
|
||||
UNET_TARGET_REPLACE_NAME = [
|
||||
"conv_in",
|
||||
"conv_out",
|
||||
"time_embedding.linear_1",
|
||||
"time_embedding.linear_2",
|
||||
]
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
||||
@@ -49,6 +158,13 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
network_module: Type[object] = LoConSpecialModule,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# call ToolkitNetworkMixin super
|
||||
super().__init__(
|
||||
**kwargs
|
||||
)
|
||||
# call the parent of the parent LycorisNetwork
|
||||
super(LycorisNetwork, self).__init__()
|
||||
|
||||
# LyCORIS unique stuff
|
||||
if dropout is None:
|
||||
dropout = 0
|
||||
@@ -57,19 +173,162 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
|
||||
if module_dropout is None:
|
||||
module_dropout = 0
|
||||
|
||||
super().__init__(
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
lora_dim=lora_dim,
|
||||
conv_lora_dim=conv_lora_dim,
|
||||
alpha=alpha,
|
||||
conv_alpha=conv_alpha,
|
||||
use_cp=use_cp,
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
network_module=network_module,
|
||||
**kwargs,
|
||||
)
|
||||
self.multiplier = multiplier
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if not self.ENABLE_CONV:
|
||||
conv_lora_dim = 0
|
||||
|
||||
self.conv_lora_dim = int(conv_lora_dim)
|
||||
if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim:
|
||||
print('Apply different lora dim for conv layer')
|
||||
print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}')
|
||||
elif self.conv_lora_dim == 0:
|
||||
print('Disable conv layer')
|
||||
|
||||
self.alpha = alpha
|
||||
self.conv_alpha = float(conv_alpha)
|
||||
if self.conv_lora_dim and self.alpha != self.conv_alpha:
|
||||
print('Apply different alpha value for conv layer')
|
||||
print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}')
|
||||
|
||||
if 1 >= dropout >= 0:
|
||||
print(f'Use Dropout value: {dropout}')
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
prefix,
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules,
|
||||
target_replace_names=[]
|
||||
) -> List[network_module]:
|
||||
print('Create LyCORIS Module')
|
||||
loras = []
|
||||
# remove this
|
||||
named_modules = root_module.named_modules()
|
||||
modules = root_module.modules()
|
||||
# add a few to tthe generator
|
||||
|
||||
for name, module in named_modules:
|
||||
module_name = module.__class__.__name__
|
||||
if module_name in target_replace_modules:
|
||||
if module_name in self.MODULE_ALGO_MAP:
|
||||
algo = self.MODULE_ALGO_MAP[module_name]
|
||||
else:
|
||||
algo = network_module
|
||||
for child_name, child_module in module.named_modules():
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
if lora_name.startswith('lora_unet_input_blocks_1_0_emb_layers_1'):
|
||||
print(f"{lora_name}")
|
||||
|
||||
if child_module.__class__.__name__ in LINEAR_MODULES and lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, child_module, self.multiplier,
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
**kwargs
|
||||
)
|
||||
elif child_module.__class__.__name__ in CONV_MODULES:
|
||||
k_size, *_ = child_module.kernel_size
|
||||
if k_size == 1 and lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, child_module, self.multiplier,
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
**kwargs
|
||||
)
|
||||
elif conv_lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, child_module, self.multiplier,
|
||||
self.conv_lora_dim, self.conv_alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
loras.append(lora)
|
||||
elif name in target_replace_names:
|
||||
if name in self.NAME_ALGO_MAP:
|
||||
algo = self.NAME_ALGO_MAP[name]
|
||||
else:
|
||||
algo = network_module
|
||||
lora_name = prefix + '.' + name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
if module.__class__.__name__ == 'Linear' and lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, module, self.multiplier,
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
**kwargs
|
||||
)
|
||||
elif module.__class__.__name__ == 'Conv2d':
|
||||
k_size, *_ = module.kernel_size
|
||||
if k_size == 1 and lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, module, self.multiplier,
|
||||
self.lora_dim, self.alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
**kwargs
|
||||
)
|
||||
elif conv_lora_dim > 0:
|
||||
lora = algo(
|
||||
lora_name, module, self.multiplier,
|
||||
self.conv_lora_dim, self.conv_alpha,
|
||||
self.dropout, self.rank_dropout, self.module_dropout,
|
||||
use_cp,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
loras.append(lora)
|
||||
return loras
|
||||
|
||||
if network_module == GLoRAModule:
|
||||
print('GLoRA enabled, only train transformer')
|
||||
# only train transformer (for GLoRA)
|
||||
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE = [
|
||||
"Transformer2DModel",
|
||||
"Attention",
|
||||
]
|
||||
LycorisSpecialNetwork.UNET_TARGET_REPLACE_NAME = []
|
||||
|
||||
if isinstance(text_encoder, list):
|
||||
text_encoders = text_encoder
|
||||
use_index = True
|
||||
else:
|
||||
text_encoders = [text_encoder]
|
||||
use_index = False
|
||||
|
||||
self.text_encoder_loras = []
|
||||
for i, te in enumerate(text_encoders):
|
||||
self.text_encoder_loras.extend(create_modules(
|
||||
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
|
||||
te,
|
||||
LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
))
|
||||
print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
|
||||
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
|
||||
print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
self.weights_sd = None
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||
names.add(lora.lora_name)
|
||||
|
||||
@@ -21,9 +21,11 @@ class ToolkitModuleMixin:
|
||||
def __init__(
|
||||
self: Module,
|
||||
*args,
|
||||
call_super_init: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
if call_super_init:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_checkpointing = False
|
||||
self.is_normalizing = False
|
||||
self.normalize_scaler = 1.0
|
||||
@@ -74,7 +76,10 @@ class ToolkitModuleMixin:
|
||||
if hasattr(self, 'lora_mid') and hasattr(self, 'cp') and self.cp:
|
||||
lx = self.lora_mid(self.lora_down(x))
|
||||
else:
|
||||
lx = self.lora_down(x)
|
||||
try:
|
||||
lx = self.lora_down(x)
|
||||
except RuntimeError as e:
|
||||
print(f"Error in {self.__class__.__name__} lora_down")
|
||||
|
||||
if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
|
||||
lx = self.dropout(lx)
|
||||
@@ -202,7 +207,7 @@ class ToolkitNetworkMixin:
|
||||
self._is_normalizing: bool = False
|
||||
self.is_sdxl = is_sdxl
|
||||
self.is_v2 = is_v2
|
||||
super().__init__(*args, **kwargs)
|
||||
# super().__init__(*args, **kwargs)
|
||||
|
||||
def get_keymap(self: Network):
|
||||
if self.is_sdxl:
|
||||
@@ -219,7 +224,7 @@ class ToolkitNetworkMixin:
|
||||
# check if file exists
|
||||
if os.path.exists(keymap_path):
|
||||
with open(keymap_path, 'r') as f:
|
||||
keymap = json.load(f)
|
||||
keymap = json.load(f)['ldm_diffusers_keymap']
|
||||
|
||||
return keymap
|
||||
|
||||
|
||||
@@ -367,9 +367,9 @@ class StableDiffusion:
|
||||
# was trained on 0.7 (I believe)
|
||||
|
||||
grs = gen_config.guidance_rescale
|
||||
# if grs is None or grs < 0.00001:
|
||||
# grs = 0.7
|
||||
grs = 0.0
|
||||
if grs is None or grs < 0.00001:
|
||||
grs = 0.7
|
||||
# grs = 0.0
|
||||
|
||||
extra = {}
|
||||
if sampler.startswith("sample_"):
|
||||
|
||||
Reference in New Issue
Block a user