Added LoCON from LyCORIS

This commit is contained in:
Jaret Burkett
2023-09-04 08:48:07 -06:00
parent fa8fc32c0a
commit a4c3507a62
7 changed files with 506 additions and 556 deletions

View File

@@ -93,6 +93,7 @@ class SDTrainer(BaseSDTrainProcess):
# back propagate loss to free ram
loss.backward()
torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
flush()
# apply gradients
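Editor's note: the clipping call added above follows the standard PyTorch pattern, applied to self.params, the flat list of trainable tensors that BaseSDTrainProcess now collects (see the next file). A minimal sketch of the same pattern, assuming params is already such a flat list, an optimizer and loss already exist, and max_grad_norm uses the new config default of 1.0:

import torch

# sketch only: clip the global gradient norm before the optimizer step
loss.backward()
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)  # rescales all grads in place if their combined norm exceeds 1.0
optimizer.step()
optimizer.zero_grad()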

View File

@@ -1,5 +1,6 @@
import copy
import glob
import inspect
from collections import OrderedDict
import os
from typing import Union
@@ -10,6 +11,8 @@ from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.network_mixins import Network
from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.sampler import get_sampler
@@ -74,6 +77,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
raw_datasets = preprocess_dataset_raw_config(raw_datasets)
self.datasets = None
self.datasets_reg = None
self.params = []
if raw_datasets is not None and len(raw_datasets) > 0:
for raw_dataset in raw_datasets:
dataset = DatasetConfig(**raw_dataset)
@@ -120,7 +124,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
# to hold network if there is one
self.network = None
self.network: Union[Network, None] = None
self.embedding = None
def sample(self, step=None, is_first=False):
@@ -424,25 +428,54 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise_scheduler = self.sd.noise_scheduler
if self.train_config.xformers:
vae.set_use_memory_efficient_attention_xformers(True)
vae.enable_xformers_memory_efficient_attention()
unet.enable_xformers_memory_efficient_attention()
if isinstance(text_encoder, list):
for te in text_encoder:
# if it has it
if hasattr(te, 'enable_xformers_memory_efficient_attention'):
te.enable_xformers_memory_efficient_attention()
if self.train_config.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# if isinstance(text_encoder, list):
# for te in text_encoder:
# te.enable_gradient_checkpointing()
# else:
# text_encoder.enable_gradient_checkpointing()
if isinstance(text_encoder, list):
for te in text_encoder:
if hasattr(te, 'enable_gradient_checkpointing'):
te.enable_gradient_checkpointing()
if hasattr(te, "gradient_checkpointing_enable"):
te.gradient_checkpointing_enable()
else:
if hasattr(text_encoder, 'enable_gradient_checkpointing'):
text_encoder.enable_gradient_checkpointing()
if hasattr(text_encoder, "gradient_checkpointing_enable"):
text_encoder.gradient_checkpointing_enable()
if isinstance(text_encoder, list):
for te in text_encoder:
te.requires_grad_(False)
te.eval()
else:
text_encoder.requires_grad_(False)
text_encoder.eval()
unet.to(self.device_torch, dtype=dtype)
unet.requires_grad_(False)
unet.eval()
vae = vae.to(torch.device('cpu'), dtype=dtype)
vae.requires_grad_(False)
vae.eval()
flush()
if self.network_config is not None:
self.network = LoRASpecialNetwork(
# TODO should we completely switch to LycorisSpecialNetwork?
# default to LoCon if there are any conv layers or if the type explicitly requests it
NetworkClass = LoRASpecialNetwork
if self.network_config.conv is not None and self.network_config.conv > 0:
NetworkClass = LycorisSpecialNetwork
if self.network_config.type.lower() in ('locon', 'lycoris'):
NetworkClass = LycorisSpecialNetwork
self.network = NetworkClass(
text_encoder=text_encoder,
unet=unet,
lora_dim=self.network_config.linear,
@@ -468,14 +501,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
self.network.prepare_grad_etc(text_encoder, unet)
flush()
params = self.get_params()
if not params:
# LyCORIS doesn't have default_lr
config = {
'text_encoder_lr': self.train_config.lr,
'unet_lr': self.train_config.lr,
}
sig = inspect.signature(self.network.prepare_optimizer_params)
if 'default_lr' in sig.parameters:
config['default_lr'] = self.train_config.lr
params = self.network.prepare_optimizer_params(
text_encoder_lr=self.train_config.lr,
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
**config
)
if self.train_config.gradient_checkpointing:
@@ -490,6 +530,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.print(f"Loading from {latest_save_path}")
self.load_weights(latest_save_path)
self.network.multiplier = 1.0
flush()
elif self.embed_config is not None:
self.embedding = Embedding(
sd=self.sd,
@@ -508,7 +550,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if not params:
# set trainable params
params = self.embedding.get_trainable_params()
flush()
else:
# set them to train or not
if self.train_config.train_unet:
@@ -546,9 +588,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
flush()
### HOOK ###
params = self.hook_add_extra_train_params(params)
self.params = []
for param in params:
if isinstance(param, dict):
self.params += param['params']
else:
self.params.append(param)
optimizer_type = self.train_config.optimizer.lower()
optimizer = get_optimizer(params, optimizer_type, learning_rate=self.train_config.lr,
@@ -568,6 +617,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
self.lr_scheduler = lr_scheduler
flush()
### HOOK ###
self.hook_before_train_loop()
@@ -639,7 +689,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# turn on normalization if we are using it and it is not on
if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
self.network.is_normalizing = True
flush()
### HOOK ###
loss_dict = self.hook_train_loop(batch)
flush()
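Editor's note: the default_lr handling above exists because LycorisNetwork.prepare_optimizer_params does not accept that keyword while the kohya-style LoRA network does, so the signature is inspected and the argument is only passed when supported. A small self-contained sketch of the same pattern, with a hypothetical prepare() standing in for either network's method:

import inspect

def prepare(text_encoder_lr, unet_lr):  # hypothetical stand-in with no default_lr parameter
    return [{'params': [], 'lr': unet_lr}, {'params': [], 'lr': text_encoder_lr}]

config = {'text_encoder_lr': 1e-4, 'unet_lr': 1e-4}
if 'default_lr' in inspect.signature(prepare).parameters:
    config['default_lr'] = 1e-4  # only added when the callee accepts it
params = prepare(**config)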

View File

@@ -37,9 +37,11 @@ class SampleConfig:
self.ext: ImgExt = kwargs.get('format', 'jpg')
NetworkType = Literal['lora', 'locon', 'lycoris']
class NetworkConfig:
def __init__(self, **kwargs):
self.type: str = kwargs.get('type', 'lora')
self.type: NetworkType = kwargs.get('type', 'lora')
rank = kwargs.get('rank', None)
linear = kwargs.get('linear', None)
if rank is not None:
@@ -86,6 +88,7 @@ class TrainConfig:
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
self.max_grad_norm = kwargs.get('max_grad_norm', 1.0)
class ModelConfig:
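Editor's note: with the NetworkType literal above and the selection logic in BaseSDTrainProcess, a job config only has to name the type or give a conv rank to get the LyCORIS path. A hedged sketch of kwargs that would select LycorisSpecialNetwork, assuming the raw config section maps straight onto NetworkConfig(**kwargs) and that 'conv' is the key read as network_config.conv:

network_config = NetworkConfig(
    type='locon',  # 'lycoris' routes the same way; plain 'lora' keeps LoRASpecialNetwork
    linear=16,     # linear rank ('rank' is accepted as an alias)
    conv=8,        # assumed key; any value > 0 also forces the LyCORIS network
)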

View File

@@ -1,243 +0,0 @@
# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
# - https://github.com/p1atdev/LECO/blob/main/lora.py
import os
import math
from typing import Optional, List, Type, Set, Literal
from collections import OrderedDict
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file
from toolkit.metadata import add_model_hash_to_meta
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
"Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
]
UNET_TARGET_REPLACE_MODULE_CONV = [
"ResnetBlock2D",
"Downsample2D",
"Upsample2D",
] # LoCon: 3x3 conv layers
LORA_PREFIX_UNET = "lora_unet"
DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
TRAINING_METHODS = Literal[
"noxattn", # train all layers except x-attns and time_embed layers
"innoxattn", # train all layers except self attention layers
"selfattn", # ESD-u, train only self attention layers
"xattn", # ESD-x, train only x attention layers
"full", # train all layers
# "notime",
# "xlayer",
# "outxattn",
# "outsattn",
# "inxattn",
# "inmidsattn",
# "selflayer",
]
class LoRAModule(nn.Module):
"""
Replaces the forward method of the original Linear layer instead of replacing the Linear module itself.
"""
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Linear":
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
elif org_module.__class__.__name__ == "Conv2d": # just in case
in_dim = org_module.in_channels
out_dim = org_module.out_channels
self.lora_dim = min(self.lora_dim, in_dim, out_dim)
if self.lora_dim != lora_dim:
print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = nn.Conv2d(
in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
)
self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
# same as microsoft's
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
self.org_module = org_module # remove in applying
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
return (
self.org_forward(x)
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
)
class LoRANetwork(nn.Module):
def __init__(
self,
unet: UNet2DConditionModel,
rank: int = 4,
multiplier: float = 1.0,
alpha: float = 1.0,
train_method: TRAINING_METHODS = "full",
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = rank
self.alpha = alpha
# LoRA only
self.module = LoRAModule
# create the LoRA modules for the U-Net
self.unet_loras = self.create_modules(
LORA_PREFIX_UNET,
unet,
DEFAULT_TARGET_REPLACE,
self.lora_dim,
self.multiplier,
train_method=train_method,
)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
# assertion: check that there are no duplicate names
lora_names = set()
for lora in self.unet_loras:
assert (
lora.lora_name not in lora_names
), f"duplicated lora name: {lora.lora_name}. {lora_names}"
lora_names.add(lora.lora_name)
# apply them
for lora in self.unet_loras:
lora.apply_to()
self.add_module(
lora.lora_name,
lora,
)
del unet
torch.cuda.empty_cache()
def create_modules(
self,
prefix: str,
root_module: nn.Module,
target_replace_modules: List[str],
rank: int,
multiplier: float,
train_method: TRAINING_METHODS,
) -> list:
loras = []
for name, module in root_module.named_modules():
if train_method == "noxattn": # Cross Attention と Time Embed 以外学習
if "attn2" in name or "time_embed" in name:
continue
elif train_method == "innoxattn": # Cross Attention 以外学習
if "attn2" in name:
continue
elif train_method == "selfattn": # Self Attention のみ学習
if "attn1" not in name:
continue
elif train_method == "xattn": # Cross Attention のみ学習
if "attn2" not in name:
continue
elif train_method == "full": # 全部学習
pass
else:
raise NotImplementedError(
f"train_method: {train_method} is not implemented."
)
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ in ["Linear", "Conv2d"]:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
print(f"{lora_name}")
lora = self.module(
lora_name, child_module, multiplier, rank, self.alpha
)
loras.append(lora)
return loras
def prepare_optimizer_params(self):
all_params = []
if self.unet_loras: # effectively the only case
params = []
[params.extend(lora.parameters()) for lora in self.unet_loras]
param_data = {"params": params}
all_params.append(param_data)
return all_params
def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
state_dict = self.state_dict()
if metadata is None:
metadata = OrderedDict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
for key in list(state_dict.keys()):
if not key.startswith("lora"):
# remove any not lora
del state_dict[key]
metadata = add_model_hash_to_meta(state_dict, metadata)
if os.path.splitext(file)[1] == ".safetensors":
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def __enter__(self):
for lora in self.unet_loras:
lora.multiplier = 1.0
def __exit__(self, exc_type, exc_value, tb):
for lora in self.unet_loras:
lora.multiplier = 0
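Editor's note: the deleted module above implemented LoRA by swapping each target layer's forward; the same mechanics now live in lora_special.py plus the mixins below. A stripped-down sketch of that pattern for a plain nn.Linear, using the same alpha/rank scaling as the removed LoRAModule:

import math
import torch.nn as nn

class TinyLoRA(nn.Module):
    def __init__(self, org_module: nn.Linear, rank: int = 4, alpha: float = 1.0, multiplier: float = 1.0):
        super().__init__()
        self.lora_down = nn.Linear(org_module.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, org_module.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)      # the adapter starts as a no-op
        self.scale = alpha / rank
        self.multiplier = multiplier
        self.org_forward = org_module.forward    # keep the original forward
        org_module.forward = self.forward        # and replace it in place

    def forward(self, x):
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale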

View File

@@ -9,6 +9,7 @@ from typing import List, Optional, Dict, Type, Union
import torch
from transformers import CLIPTextModel
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
from .paths import SD_SCRIPTS_ROOT, KEYMAPS_ROOT
from .train_tools import get_torch_dtype
@@ -21,7 +22,7 @@ from torch.utils.checkpoint import checkpoint
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
class LoRAModule(torch.nn.Module):
class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
"""
Replaces the forward method of the original Linear layer instead of replacing the Linear module itself.
"""
@@ -40,6 +41,7 @@ class LoRAModule(torch.nn.Module):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
self.scalar = torch.tensor(1.0)
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
@@ -89,153 +91,8 @@ class LoRAModule(torch.nn.Module):
self.org_module.forward = self.forward
del self.org_module
# this allows us to set different multipliers on a per item in a batch basis
# allowing us to run positive and negative weights in the same batch
# really only useful for slider training for now
def get_multiplier(self, lora_up):
with torch.no_grad():
batch_size = lora_up.size(0)
# batch will have all negative prompts first and positive prompts second
# our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
# if there is more than our multiplier, it is likely a batch size increase, so we need to
# interleave the multipliers
if isinstance(self.multiplier, list):
if len(self.multiplier) == 1:
# single item, just return it
return self.multiplier[0]
elif len(self.multiplier) == batch_size:
# not doing CFG
multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
else:
# we have a list of multipliers, so we need to get the multiplier for this batch
multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
# should be 1 if the total batch size was 1
num_interleaves = (batch_size // 2) // len(self.multiplier)
multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
# match lora_up rank
if len(lora_up.size()) == 2:
multiplier_tensor = multiplier_tensor.view(-1, 1)
elif len(lora_up.size()) == 3:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
elif len(lora_up.size()) == 4:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
return multiplier_tensor.detach()
else:
return self.multiplier
def _call_forward(self, x):
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return 0.0 # added to original forward
lx = self.lora_down(x)
# normal dropout
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
# this could be computed from the mask, but rank_dropout is used here for an augmentation-like effect
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
else:
scale = self.scale
lx = self.lora_up(lx)
return lx * scale
def forward(self, x):
org_forwarded = self.org_forward(x)
lora_output = self._call_forward(x)
multiplier = self.get_multiplier(lora_output)
if self.is_normalizing:
with torch.no_grad():
# do this calculation without the set multiplier; keep its polarity, but scale it up and clamp to ±1.0
if isinstance(multiplier, torch.Tensor):
norm_multiplier = multiplier.clone().detach() * 10
norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0)
else:
norm_multiplier = multiplier
# get a dim array from orig forward that had index of all dimensions except the batch and channel
# Calculate the target magnitude for the combined output
orig_max = torch.max(torch.abs(org_forwarded))
# Calculate the additional increase in magnitude that lora_output would introduce
potential_max_increase = torch.max(torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded))
epsilon = 1e-6 # Small constant to avoid division by zero
# Calculate the scaling factor for the lora_output
# to ensure that the potential increase in magnitude doesn't change the original max
normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
normalize_scaler = normalize_scaler.detach()
# save the scaler so it can be applied later
self.normalize_scaler = normalize_scaler.clone().detach()
lora_output *= normalize_scaler
return org_forwarded + (lora_output * multiplier)
def enable_gradient_checkpointing(self):
self.is_checkpointing = True
def disable_gradient_checkpointing(self):
self.is_checkpointing = False
@torch.no_grad()
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
"""
Applies the previous normalization calculation to the module.
This must be called before saving or normalization will be lost.
It is probably best to call it after each batch as well.
We just scale the up/down weights to match this scaler
:return:
"""
# get state dict
state_dict = self.state_dict()
dtype = state_dict['lora_up.weight'].dtype
device = state_dict['lora_up.weight'].device
# todo should we do this at fp32?
if isinstance(self.normalize_scaler, torch.Tensor):
scaler = self.normalize_scaler.clone().detach()
else:
scaler = torch.tensor(self.normalize_scaler).to(device, dtype=dtype)
total_module_scale = scaler / target_normalize_scaler
num_modules_layers = 2 # up and down
up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
.to(device, dtype=dtype)
# apply the scaler to the up and down weights
for key in state_dict.keys():
if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
# do it in place so the params are updated
state_dict[key] *= up_down_scale
# reset the normalization scaler
self.normalize_scaler = target_normalize_scaler
class LoRASpecialNetwork(LoRANetwork):
class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
NUM_OF_BLOCKS = 12 # number of up/down layers in the full model
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
@@ -445,154 +302,3 @@ class LoRASpecialNetwork(LoRANetwork):
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def get_keymap(self):
if self.is_sdxl:
keymap_tail = 'sdxl'
elif self.is_v2:
keymap_tail = 'sd2'
else:
keymap_tail = 'sd1'
# load keymap
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
keymap = None
# check if file exists
if os.path.exists(keymap_name):
with open(keymap_name, 'r') as f:
keymap = json.load(f)
return keymap
def save_weights(self, file, dtype, metadata):
keymap = self.get_keymap()
save_keymap = {}
if keymap is not None:
for ldm_key, diffusers_key in keymap.items():
# invert them
save_keymap[diffusers_key] = ldm_key
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
save_dict = OrderedDict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
save_key = save_keymap[key] if key in save_keymap else key
save_dict[save_key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(save_dict, file, metadata)
else:
torch.save(save_dict, file)
def load_weights(self, file):
# allows us to save and load to and from ldm weights
keymap = self.get_keymap()
keymap = {} if keymap is None else keymap
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
load_sd = OrderedDict()
for key, value in weights_sd.items():
load_key = keymap[key] if key in keymap else key
load_sd[load_key] = value
info = self.load_state_dict(load_sd, False)
return info
@property
def multiplier(self) -> Union[float, List[float]]:
return self._multiplier
@multiplier.setter
def multiplier(self, value: Union[float, List[float]]):
self._multiplier = value
self._update_lora_multiplier()
def _update_lora_multiplier(self):
if self.is_active:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.multiplier = self._multiplier
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.multiplier = self._multiplier
else:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.multiplier = 0
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.multiplier = 0
# called when the context manager is entered
# ie: with network:
def __enter__(self):
self.is_active = True
self._update_lora_multiplier()
def __exit__(self, exc_type, exc_value, tb):
self.is_active = False
self._update_lora_multiplier()
def force_to(self, device, dtype):
self.to(device, dtype)
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras
if hasattr(self, 'text_encoder_loras'):
loras += self.text_encoder_loras
for lora in loras:
lora.to(device, dtype)
def get_all_modules(self):
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras
if hasattr(self, 'text_encoder_loras'):
loras += self.text_encoder_loras
return loras
def _update_checkpointing(self):
for module in self.get_all_modules():
if self.is_checkpointing:
module.enable_gradient_checkpointing()
else:
module.disable_gradient_checkpointing()
def enable_gradient_checkpointing(self):
# not supported
self.is_checkpointing = True
self._update_checkpointing()
def disable_gradient_checkpointing(self):
# not supported
self.is_checkpointing = False
self._update_checkpointing()
@property
def is_normalizing(self) -> bool:
return self._is_normalizing
@is_normalizing.setter
def is_normalizing(self, value: bool):
self._is_normalizing = value
for module in self.get_all_modules():
module.is_normalizing = self._is_normalizing
def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
for module in self.get_all_modules():
module.apply_stored_normalizer(target_normalize_scaler)
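Editor's note: everything removed above was moved, nearly verbatim, into the new toolkit/network_mixins.py; listing the mixin first in class LoRAModule(ToolkitModuleMixin, torch.nn.Module) puts its forward and checkpointing overrides ahead of the base class in the MRO, while cooperative super().__init__ still reaches the base. A minimal sketch of that cooperative-mixin pattern with hypothetical class names:

class ExampleMixin:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # pass through to the next class in the MRO
        self.extra_state = True

    def forward(self, x):                  # wins over Base.forward because the mixin is listed first
        return super().forward(x) * 2

class Base:
    def __init__(self, scale=1.0):
        self.scale = scale

    def forward(self, x):
        return x * self.scale

class Combined(ExampleMixin, Base):
    pass

print(Combined(scale=3.0).forward(2.0))  # 12.0: Base gives 6.0, the mixin doubles it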

View File

@@ -0,0 +1,75 @@
import os
from typing import Optional, Union, List, Type
from lycoris.kohya import LycorisNetwork, LoConModule
from torch import nn
from transformers import CLIPTextModel
from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
def __init__(
self,
lora_name,
org_module: nn.Module,
multiplier=1.0,
lora_dim=4, alpha=1,
dropout=0., rank_dropout=0., module_dropout=0.,
use_cp=False,
**kwargs,
):
super().__init__(
lora_name,
org_module,
multiplier=multiplier,
lora_dim=lora_dim, alpha=alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
use_cp=use_cp,
**kwargs,
)
class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
use_cp: Optional[bool] = False,
network_module: Type[object] = LoConSpecialModule,
**kwargs,
) -> None:
# LyCORIS unique stuff
if dropout is None:
dropout = 0
if rank_dropout is None:
rank_dropout = 0
if module_dropout is None:
module_dropout = 0
super().__init__(
text_encoder,
unet,
multiplier=multiplier,
lora_dim=lora_dim,
conv_lora_dim=conv_lora_dim,
alpha=alpha,
conv_alpha=conv_alpha,
use_cp=use_cp,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
network_module=network_module,
**kwargs,
)
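Editor's note: a hedged usage sketch for the new class, assuming text_encoder and unet come from an already loaded pipeline and that the kohya-style apply_to API inherited from LycorisNetwork is used to attach the modules (argument values are illustrative):

network = LycorisSpecialNetwork(
    text_encoder=text_encoder,  # CLIPTextModel, or a list of them for SDXL
    unet=unet,
    multiplier=1.0,
    lora_dim=16,                # linear rank
    alpha=16,
    conv_lora_dim=8,            # conv rank: this is what LoCon adds over a plain LoRA
    conv_alpha=8,
)
network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)  # assumed inherited API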

toolkit/network_mixins.py (new file, 358 additions)
View File

@@ -0,0 +1,358 @@
import json
import os
from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING
import torch
from torch import nn
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
if TYPE_CHECKING:
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule']
class ToolkitModuleMixin:
def __init__(
self: Module,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.is_checkpointing = False
self.is_normalizing = False
self.normalize_scaler = 1.0
# this allows us to set different multipliers on a per item in a batch basis
# allowing us to run positive and negative weights in the same batch
# really only useful for slider training for now
def get_multiplier(self: Module, lora_up):
with torch.no_grad():
batch_size = lora_up.size(0)
# batch will have all negative prompts first and positive prompts second
# our multiplier list is for a prompt pair. So we need to repeat it for positive and negative prompts
# if the batch is larger than our multiplier list, it is likely a batch size increase, so we need to
# interleave the multipliers
if isinstance(self.multiplier, list):
if len(self.multiplier) == 1:
# single item, just return it
return self.multiplier[0]
elif len(self.multiplier) == batch_size:
# not doing CFG
multiplier_tensor = torch.tensor(self.multiplier).to(lora_up.device, dtype=lora_up.dtype)
else:
# we have a list of multipliers, so we need to get the multiplier for this batch
multiplier_tensor = torch.tensor(self.multiplier * 2).to(lora_up.device, dtype=lora_up.dtype)
# should be 1 if the total batch size was 1
num_interleaves = (batch_size // 2) // len(self.multiplier)
multiplier_tensor = multiplier_tensor.repeat_interleave(num_interleaves)
# match lora_up rank
if len(lora_up.size()) == 2:
multiplier_tensor = multiplier_tensor.view(-1, 1)
elif len(lora_up.size()) == 3:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1)
elif len(lora_up.size()) == 4:
multiplier_tensor = multiplier_tensor.view(-1, 1, 1, 1)
return multiplier_tensor.detach()
else:
return self.multiplier
def _call_forward(self: Module, x):
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return 0.0 # added to original forward
if hasattr(self, 'lora_mid') and hasattr(self, 'cp') and self.cp:
lx = self.lora_mid(self.lora_down(x))
else:
lx = self.lora_down(x)
if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
lx = self.dropout(lx)
# normal dropout
elif self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.rank_dropout > 0 and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
# this could be computed from the mask, but rank_dropout is used here for an augmentation-like effect
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
else:
scale = self.scale
lx = self.lora_up(lx)
# handle trainable scaler method locon does
if hasattr(self, 'scalar'):
scale *= self.scalar
return lx * scale
def forward(self: Module, x):
org_forwarded = self.org_forward(x)
lora_output = self._call_forward(x)
multiplier = self.get_multiplier(lora_output)
if self.is_normalizing:
with torch.no_grad():
# do this calculation without the set multiplier; keep its polarity, but scale it up and clamp to ±1.0
if isinstance(multiplier, torch.Tensor):
norm_multiplier = multiplier.clone().detach() * 10
norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0)
else:
norm_multiplier = multiplier
# get a dim array from orig forward that had index of all dimensions except the batch and channel
# Calculate the target magnitude for the combined output
orig_max = torch.max(torch.abs(org_forwarded))
# Calculate the additional increase in magnitude that lora_output would introduce
potential_max_increase = torch.max(
torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded))
epsilon = 1e-6 # Small constant to avoid division by zero
# Calculate the scaling factor for the lora_output
# to ensure that the potential increase in magnitude doesn't change the original max
normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
normalize_scaler = normalize_scaler.detach()
# save the scaler so it can be applied later
self.normalize_scaler = normalize_scaler.clone().detach()
lora_output *= normalize_scaler
return org_forwarded + (lora_output * multiplier)
def enable_gradient_checkpointing(self: Module):
self.is_checkpointing = True
def disable_gradient_checkpointing(self: Module):
self.is_checkpointing = False
@torch.no_grad()
def apply_stored_normalizer(self: Module, target_normalize_scaler: float = 1.0):
"""
Applies the previous normalization calculation to the module.
This must be called before saving or normalization will be lost.
It is probably best to call it after each batch as well.
We just scale the up/down weights to match this scaler
:return:
"""
# get state dict
state_dict = self.state_dict()
dtype = state_dict['lora_up.weight'].dtype
device = state_dict['lora_up.weight'].device
# todo should we do this at fp32?
if isinstance(self.normalize_scaler, torch.Tensor):
scaler = self.normalize_scaler.clone().detach()
else:
scaler = torch.tensor(self.normalize_scaler).to(device, dtype=dtype)
total_module_scale = scaler / target_normalize_scaler
num_modules_layers = 2 # up and down
up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
.to(device, dtype=dtype)
# apply the scaler to the up and down weights
for key in state_dict.keys():
if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
# do it in place so the params are updated
state_dict[key] *= up_down_scale
# reset the normalization scaler
self.normalize_scaler = target_normalize_scaler
class ToolkitNetworkMixin:
def __init__(
self: Network,
*args,
train_text_encoder: Optional[bool] = True,
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
**kwargs
):
self.train_text_encoder = train_text_encoder
self.train_unet = train_unet
self.is_checkpointing = False
self._multiplier: float = 1.0
self.is_active: bool = False
self._is_normalizing: bool = False
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
super().__init__(*args, **kwargs)
def get_keymap(self: Network):
if self.is_sdxl:
keymap_tail = 'sdxl'
elif self.is_v2:
keymap_tail = 'sd2'
else:
keymap_tail = 'sd1'
# load keymap
keymap_name = f"stable_diffusion_locon_{keymap_tail}.json"
keymap_path = os.path.join(KEYMAPS_ROOT, keymap_name)
keymap = None
# check if file exists
if os.path.exists(keymap_path):
with open(keymap_path, 'r') as f:
keymap = json.load(f)
return keymap
def save_weights(self: Network, file, dtype=torch.float16, metadata=None):
keymap = self.get_keymap()
save_keymap = {}
if keymap is not None:
for ldm_key, diffusers_key in keymap.items():
# invert them
save_keymap[diffusers_key] = ldm_key
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
save_dict = OrderedDict()
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
save_key = save_keymap[key] if key in save_keymap else key
save_dict[save_key] = v
if metadata is None:
metadata = OrderedDict()
metadata = add_model_hash_to_meta(state_dict, metadata)
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(save_dict, file, metadata)
else:
torch.save(save_dict, file)
def load_weights(self: Network, file):
# allows us to save and load to and from ldm weights
keymap = self.get_keymap()
keymap = {} if keymap is None else keymap
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
load_sd = OrderedDict()
for key, value in weights_sd.items():
load_key = keymap[key] if key in keymap else key
load_sd[load_key] = value
info = self.load_state_dict(load_sd, False)
return info
@property
def multiplier(self) -> Union[float, List[float]]:
return self._multiplier
@multiplier.setter
def multiplier(self, value: Union[float, List[float]]):
self._multiplier = value
self._update_lora_multiplier()
def _update_lora_multiplier(self: Network):
if self.is_active:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.multiplier = self._multiplier
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.multiplier = self._multiplier
else:
if hasattr(self, 'unet_loras'):
for lora in self.unet_loras:
lora.multiplier = 0
if hasattr(self, 'text_encoder_loras'):
for lora in self.text_encoder_loras:
lora.multiplier = 0
# called when the context manager is entered
# ie: with network:
def __enter__(self: Network):
self.is_active = True
self._update_lora_multiplier()
def __exit__(self: Network, exc_type, exc_value, tb):
self.is_active = False
self._update_lora_multiplier()
def force_to(self: Network, device, dtype):
self.to(device, dtype)
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras
if hasattr(self, 'text_encoder_loras'):
loras += self.text_encoder_loras
for lora in loras:
lora.to(device, dtype)
def get_all_modules(self: Network):
loras = []
if hasattr(self, 'unet_loras'):
loras += self.unet_loras
if hasattr(self, 'text_encoder_loras'):
loras += self.text_encoder_loras
return loras
def _update_checkpointing(self: Network):
for module in self.get_all_modules():
if self.is_checkpointing:
module.enable_gradient_checkpointing()
else:
module.disable_gradient_checkpointing()
# def enable_gradient_checkpointing(self: Network):
# # not supported
# self.is_checkpointing = True
# self._update_checkpointing()
#
# def disable_gradient_checkpointing(self: Network):
# # not supported
# self.is_checkpointing = False
# self._update_checkpointing()
@property
def is_normalizing(self: Network) -> bool:
return self._is_normalizing
@is_normalizing.setter
def is_normalizing(self: Network, value: bool):
self._is_normalizing = value
for module in self.get_all_modules():
module.is_normalizing = self._is_normalizing
def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
for module in self.get_all_modules():
module.apply_stored_normalizer(target_normalize_scaler)
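Editor's note: to make the list-multiplier branch of get_multiplier concrete, a worked sketch of the tensor it builds when the batch is larger than the multiplier list (negative prompts first, positives second, as described in the comments above):

import torch

multiplier = [1.0, -1.0]                                 # one value per prompt pair (slider training)
batch_size = 8                                           # larger than len(multiplier): the interleave branch
t = torch.tensor(multiplier * 2)                         # duplicate for negative/positive halves -> [ 1., -1.,  1., -1.]
num_interleaves = (batch_size // 2) // len(multiplier)   # 2
t = t.repeat_interleave(num_interleaves)                 # -> [ 1.,  1., -1., -1.,  1.,  1., -1., -1.]
t = t.view(-1, 1)                                        # reshaped to broadcast against a 2-D lora_up output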