Fully tested and now supporting locon on sdxl. If you have the ram

This commit is contained in:
Jaret Burkett
2023-09-04 14:05:10 -06:00
parent a4c3507a62
commit 64a5441832
5 changed files with 320 additions and 39 deletions

View File

@@ -5,6 +5,7 @@ from collections import OrderedDict
import os
from typing import Union
from lycoris.config import PRESET
from torch.utils.data import DataLoader
from toolkit.data_loader import get_dataloader_from_datasets
@@ -468,12 +469,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
is_lycoris = False
# default to LoCON if there are any conv layers or if it is named
NetworkClass = LoRASpecialNetwork
if self.network_config.conv is not None and self.network_config.conv > 0:
NetworkClass = LycorisSpecialNetwork
is_lycoris = True
if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
NetworkClass = LycorisSpecialNetwork
is_lycoris = True
if is_lycoris:
preset = PRESET['full']
# NetworkClass.apply_preset(preset)
self.network = NetworkClass(
text_encoder=text_encoder,

View File

@@ -3,15 +3,13 @@ import math
import os
import re
import sys
from collections import OrderedDict
from typing import List, Optional, Dict, Type, Union
import torch
from transformers import CLIPTextModel
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
from .paths import SD_SCRIPTS_ROOT, KEYMAPS_ROOT
from .train_tools import get_torch_dtype
from .paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)
@@ -22,6 +20,17 @@ from torch.utils.checkpoint import checkpoint
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
# diffusers specific stuff
LINEAR_MODULES = [
'Linear',
'LoRACompatibleLinear'
# 'GroupNorm',
]
CONV_MODULES = [
'Conv2d',
'LoRACompatibleConv'
]
class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -197,8 +206,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_linear = child_module.__class__.__name__.in_(LINEAR_MODULES)
is_conv2d = child_module.__class__.__name__.in_(CONV_MODULES)
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:

View File

@@ -1,38 +1,147 @@
import math
import os
from typing import Optional, Union, List, Type
import torch
from lycoris.kohya import LycorisNetwork, LoConModule
from lycoris.modules.glora import GLoRAModule
from torch import nn
from transformers import CLIPTextModel
from torch.nn import functional as F
from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
# diffusers specific stuff
LINEAR_MODULES = [
'Linear',
'LoRACompatibleLinear'
# 'GroupNorm',
]
CONV_MODULES = [
'Conv2d',
'LoRACompatibleConv'
]
class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
def __init__(
self,
lora_name,
org_module: nn.Module,
lora_name, org_module: nn.Module,
multiplier=1.0,
lora_dim=4, alpha=1,
dropout=0., rank_dropout=0., module_dropout=0.,
use_cp=False,
**kwargs,
):
""" if alpha == 0 or None, alpha is rank (no scaling). """
# call super of
super().__init__(
lora_name,
org_module,
multiplier=multiplier,
lora_dim=lora_dim, alpha=alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
use_cp=use_cp,
**kwargs,
call_super_init=False,
**kwargs
)
# call super of super
super(LoConModule, self).__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
self.cp = False
self.scalar = nn.Parameter(torch.tensor(0.0))
orig_module_name = org_module.__class__.__name__
if orig_module_name in CONV_MODULES:
self.isconv = True
# For general LoCon
in_dim = org_module.in_channels
k_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
out_dim = org_module.out_channels
self.down_op = F.conv2d
self.up_op = F.conv2d
if use_cp and k_size != (1, 1):
self.lora_down = nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
self.lora_mid = nn.Conv2d(lora_dim, lora_dim, k_size, stride, padding, bias=False)
self.cp = True
else:
self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
elif orig_module_name in LINEAR_MODULES:
self.isconv = False
self.down_op = F.linear
self.up_op = F.linear
if orig_module_name == 'GroupNorm':
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (56320x120 and 320x32)
in_dim = org_module.num_channels
out_dim = org_module.num_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
else:
raise NotImplementedError
self.shape = org_module.weight.shape
if dropout:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = nn.Identity()
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lora_up.weight)
if self.cp:
torch.nn.init.kaiming_uniform_(self.lora_mid.weight, a=math.sqrt(5))
self.multiplier = multiplier
self.org_module = [org_module]
self.register_load_state_dict_post_hook(self.load_weight_hook)
def load_weight_hook(self, *args, **kwargs):
self.scalar = nn.Parameter(torch.ones_like(self.scalar))
class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
"ResnetBlock2D",
"Downsample2D",
"Upsample2D",
# 'UNet2DConditionModel',
# 'Conv2d',
# 'Timesteps',
# 'TimestepEmbedding',
# 'Linear',
# 'SiLU',
# 'ModuleList',
# 'DownBlock2D',
'ResnetBlock2D', # need
# 'GroupNorm',
# 'LoRACompatibleConv',
# 'LoRACompatibleLinear',
# 'Dropout',
# 'CrossAttnDownBlock2D', # needed
'Transformer2DModel', # maybe not, has duplicates
# 'BasicTransformerBlock', # duplicates
# 'LayerNorm',
# 'Attention',
# 'FeedForward',
# 'GEGLU',
# 'UpBlock2D',
# 'UNetMidBlock2DCrossAttn'
]
UNET_TARGET_REPLACE_NAME = [
"conv_in",
"conv_out",
"time_embedding.linear_1",
"time_embedding.linear_2",
]
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
@@ -49,6 +158,13 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
network_module: Type[object] = LoConSpecialModule,
**kwargs,
) -> None:
# call ToolkitNetworkMixin super
super().__init__(
**kwargs
)
# call the parent of the parent LycorisNetwork
super(LycorisNetwork, self).__init__()
# LyCORIS unique stuff
if dropout is None:
dropout = 0
@@ -57,19 +173,162 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
if module_dropout is None:
module_dropout = 0
super().__init__(
text_encoder,
unet,
multiplier=multiplier,
lora_dim=lora_dim,
conv_lora_dim=conv_lora_dim,
alpha=alpha,
conv_alpha=conv_alpha,
use_cp=use_cp,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
network_module=network_module,
**kwargs,
)
self.multiplier = multiplier
self.lora_dim = lora_dim
if not self.ENABLE_CONV:
conv_lora_dim = 0
self.conv_lora_dim = int(conv_lora_dim)
if self.conv_lora_dim and self.conv_lora_dim != self.lora_dim:
print('Apply different lora dim for conv layer')
print(f'Conv Dim: {conv_lora_dim}, Linear Dim: {lora_dim}')
elif self.conv_lora_dim == 0:
print('Disable conv layer')
self.alpha = alpha
self.conv_alpha = float(conv_alpha)
if self.conv_lora_dim and self.alpha != self.conv_alpha:
print('Apply different alpha value for conv layer')
print(f'Conv alpha: {conv_alpha}, Linear alpha: {alpha}')
if 1 >= dropout >= 0:
print(f'Use Dropout value: {dropout}')
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
# create module instances
def create_modules(
prefix,
root_module: torch.nn.Module,
target_replace_modules,
target_replace_names=[]
) -> List[network_module]:
print('Create LyCORIS Module')
loras = []
# remove this
named_modules = root_module.named_modules()
modules = root_module.modules()
# add a few to tthe generator
for name, module in named_modules:
module_name = module.__class__.__name__
if module_name in target_replace_modules:
if module_name in self.MODULE_ALGO_MAP:
algo = self.MODULE_ALGO_MAP[module_name]
else:
algo = network_module
for child_name, child_module in module.named_modules():
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
if lora_name.startswith('lora_unet_input_blocks_1_0_emb_layers_1'):
print(f"{lora_name}")
if child_module.__class__.__name__ in LINEAR_MODULES and lora_dim > 0:
lora = algo(
lora_name, child_module, self.multiplier,
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
**kwargs
)
elif child_module.__class__.__name__ in CONV_MODULES:
k_size, *_ = child_module.kernel_size
if k_size == 1 and lora_dim > 0:
lora = algo(
lora_name, child_module, self.multiplier,
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
**kwargs
)
elif conv_lora_dim > 0:
lora = algo(
lora_name, child_module, self.multiplier,
self.conv_lora_dim, self.conv_alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
**kwargs
)
else:
continue
else:
continue
loras.append(lora)
elif name in target_replace_names:
if name in self.NAME_ALGO_MAP:
algo = self.NAME_ALGO_MAP[name]
else:
algo = network_module
lora_name = prefix + '.' + name
lora_name = lora_name.replace('.', '_')
if module.__class__.__name__ == 'Linear' and lora_dim > 0:
lora = algo(
lora_name, module, self.multiplier,
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
**kwargs
)
elif module.__class__.__name__ == 'Conv2d':
k_size, *_ = module.kernel_size
if k_size == 1 and lora_dim > 0:
lora = algo(
lora_name, module, self.multiplier,
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
**kwargs
)
elif conv_lora_dim > 0:
lora = algo(
lora_name, module, self.multiplier,
self.conv_lora_dim, self.conv_alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
**kwargs
)
else:
continue
else:
continue
loras.append(lora)
return loras
if network_module == GLoRAModule:
print('GLoRA enabled, only train transformer')
# only train transformer (for GLoRA)
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
"Attention",
]
LycorisSpecialNetwork.UNET_TARGET_REPLACE_NAME = []
if isinstance(text_encoder, list):
text_encoders = text_encoder
use_index = True
else:
text_encoders = [text_encoder]
use_index = False
self.text_encoder_loras = []
for i, te in enumerate(text_encoders):
self.text_encoder_loras.extend(create_modules(
LycorisSpecialNetwork.LORA_PREFIX_TEXT_ENCODER + (f'{i + 1}' if use_index else ''),
te,
LycorisSpecialNetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
))
print(f"create LyCORIS for Text Encoder: {len(self.text_encoder_loras)} modules.")
self.unet_loras = create_modules(LycorisSpecialNetwork.LORA_PREFIX_UNET, unet,
LycorisSpecialNetwork.UNET_TARGET_REPLACE_MODULE)
print(f"create LyCORIS for U-Net: {len(self.unet_loras)} modules.")
self.weights_sd = None
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)

View File

@@ -21,9 +21,11 @@ class ToolkitModuleMixin:
def __init__(
self: Module,
*args,
call_super_init: bool = True,
**kwargs
):
super().__init__(*args, **kwargs)
if call_super_init:
super().__init__(*args, **kwargs)
self.is_checkpointing = False
self.is_normalizing = False
self.normalize_scaler = 1.0
@@ -74,7 +76,10 @@ class ToolkitModuleMixin:
if hasattr(self, 'lora_mid') and hasattr(self, 'cp') and self.cp:
lx = self.lora_mid(self.lora_down(x))
else:
lx = self.lora_down(x)
try:
lx = self.lora_down(x)
except RuntimeError as e:
print(f"Error in {self.__class__.__name__} lora_down")
if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
lx = self.dropout(lx)
@@ -202,7 +207,7 @@ class ToolkitNetworkMixin:
self._is_normalizing: bool = False
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
super().__init__(*args, **kwargs)
# super().__init__(*args, **kwargs)
def get_keymap(self: Network):
if self.is_sdxl:
@@ -219,7 +224,7 @@ class ToolkitNetworkMixin:
# check if file exists
if os.path.exists(keymap_path):
with open(keymap_path, 'r') as f:
keymap = json.load(f)
keymap = json.load(f)['ldm_diffusers_keymap']
return keymap

View File

@@ -367,9 +367,9 @@ class StableDiffusion:
# was trained on 0.7 (I believe)
grs = gen_config.guidance_rescale
# if grs is None or grs < 0.00001:
# grs = 0.7
grs = 0.0
if grs is None or grs < 0.00001:
grs = 0.7
# grs = 0.0
extra = {}
if sampler.startswith("sample_"):