https://github.com/ostris/ai-toolkit.git
Added lorm. WIP
@@ -18,6 +18,7 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.embedding import Embedding
 from toolkit.ip_adapter import IPAdapter
 from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.lorm import convert_diffusers_unet_to_lorm
 from toolkit.lycoris_special import LycorisSpecialNetwork
 from toolkit.network_mixins import Network
 from toolkit.optimizer import get_optimizer
@@ -126,6 +127,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         is_training_adapter = self.adapter_config is not None and self.adapter_config.train
 
+        self.do_lorm = self.get_conf('do_lorm', False)
+
         # get the device state preset based on what we are training
         self.train_device_state_preset = get_train_sd_device_state_preset(
             device=self.device_torch,
@@ -675,6 +678,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # torch.autograd.set_detect_anomaly(True)
         # run base process run
         BaseTrainProcess.run(self)
+        params = []
 
         ### HOOK ###
         self.hook_before_model_load()
@@ -708,6 +712,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # run base sd process run
         self.sd.load_model()
 
+        if self.do_lorm:
+            train_modules = convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27)
+            for module in train_modules:
+                p = module.parameters()
+                for param in p:
+                    param.requires_grad_(True)
+                    params.append(param)
+
+
         dtype = get_torch_dtype(self.train_config.dtype)
 
         # model is loaded from BaseSDProcess
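
The block above reads `do_lorm` from the process config (via `self.get_conf('do_lorm', False)` in the earlier hunk) and makes the converted modules' parameters the trainable set. A minimal standalone sketch of the same wiring, assuming a loaded diffusers UNet named `unet` (the optimizer and learning rate are illustrative, not from this commit):

    import torch
    from toolkit.lorm import convert_diffusers_unet_to_lorm

    # unet: a loaded diffusers UNet2DConditionModel
    train_modules = convert_diffusers_unet_to_lorm(unet, 'ratio', 0.27)
    params = [p for m in train_modules for p in m.parameters()]
    for p in params:
        p.requires_grad_(True)
    optimizer = torch.optim.AdamW(params, lr=1e-4)  # illustrative optimizer/lr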
@@ -767,7 +780,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.datasets_reg is not None:
             self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
                                                                 self.sd)
-        params = []
         if not self.is_fine_tuning:
             if self.network_config is not None:
                 # TODO should we completely switch to LycorisSpecialNetwork?
@@ -903,8 +915,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # set the device state preset before getting params
         self.sd.set_device_state(self.train_device_state_preset)
 
-        params = self.get_params()
-        if not params:
+
+        # params = self.get_params()
+        if len(params) == 0:
             # will only return savable weights and ones with grad
             params = self.sd.prepare_optimizer_params(
                 unet=self.train_config.train_unet,

toolkit/lorm.py (new file, 424 lines)
@@ -0,0 +1,424 @@
from typing import Union, Tuple, Literal, Optional

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from torch import Tensor
from tqdm import tqdm

conv = nn.Conv2d
lin = nn.Linear
_size_2_t = Union[int, Tuple[int, int]]

ExtractMode = Literal[
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentile',
    'percentage'
]
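
# How each mode picks the rank (see extract_conv / extract_linear below):
# fixed: rank = mode_param; threshold: keep singular values > mode_param;
# ratio: keep singular values > mode_param * max(S); quantile / percentile:
# rank set by where the cumulative sum of singular values crosses
# mode_param * sum(S); percentage: rank chosen so the factored layer keeps
# roughly mode_param of the original parameter count.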

LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear'
]
CONV_MODULES = [
    # 'Conv2d',
    # 'LoRACompatibleConv'
]

UNET_TARGET_REPLACE_MODULE = [
    "Transformer2DModel",
    # "BasicTransformerBlock",
    # "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
]

UNET_TARGET_REPLACE_NAME = [
    "conv_in",
    "conv_out",
    "time_embedding.linear_1",
    "time_embedding.linear_2",
]

UNET_MODULES_TO_AVOID = [
]


# Low Rank Convolution
class LoRMCon2d(nn.Module):
    def __init__(
            self,
            in_channels: int,
            lorm_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 'same',
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.lorm_channels = lorm_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        self.down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=lorm_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype
        )

        # Kernel size on the up is always 1x1.
        # I don't think you could calculate a dual 3x3, or I can't at least.
        self.up = nn.Conv2d(
            in_channels=lorm_channels,
            out_channels=out_channels,
            kernel_size=(1, 1),
            stride=1,
            padding='same',
            dilation=1,
            groups=1,
            bias=bias,
            padding_mode='zeros',
            device=device,
            dtype=dtype
        )

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        x = input
        x = self.down(x)
        x = self.up(x)
        return x
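
# Rough arithmetic on the savings (illustrative sizes, not from this repo):
# a 3x3 conv with 640 in and 640 out channels holds 640 * 640 * 9 = 3,686,400
# weights; factored at lorm_channels=160, the down holds 160 * 640 * 9 =
# 921,600 and the 1x1 up holds 640 * 160 = 102,400, about 28% of the original.
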
class LoRMLinear(nn.Module):
    def __init__(
            self,
            in_features: int,
            lorm_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.lorm_features = lorm_features
        self.out_features = out_features

        self.down = nn.Linear(
            in_features=in_features,
            out_features=lorm_features,
            bias=False,
            device=device,
            dtype=dtype
        )
        self.up = nn.Linear(
            in_features=lorm_features,
            out_features=out_features,
            bias=bias,
            # bias=True,
            device=device,
            dtype=dtype
        )

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        x = input
        x = self.down(x)
        x = self.up(x)
        return x
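
# Intended equivalence (a sketch): with W (out, in) factored so that
# up.weight @ down.weight ~ W, up(down(x)) = x @ down.weight.T @ up.weight.T
# + bias matches the original nn.Linear, while the parameter count drops from
# in * out to lorm_features * (in + out).
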
def extract_conv(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu'
) -> Tuple[Tensor, Tensor, int, Tensor]:
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape

    U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1))
    if mode == 'percentage':
        assert 0 <= mode_param <= 1  # ensure it is a valid percentage
        original_params = out_ch * in_ch * kernel_size * kernel_size
        desired_params = mode_param * original_params
        # factored params = lora_rank * (in_ch * k * k) + out_ch * lora_rank; solve for lora_rank
        lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch))
    elif mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile"/"percentile" or "percentage"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        lora_rank = int(out_ch / 2)
        print("rank is higher than it should be")
        # return weight, 'full'

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, lora_rank, diff
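
# Shape bookkeeping for the factorization above: the (out_ch, in_ch * k * k)
# weight matrix is cut to U[:, :r] @ diag(S[:r]) @ Vh[:r, :]. Vh reshapes to a
# k x k conv with r output filters (the down path) and U @ diag(S) reshapes to
# a 1x1 conv mapping r -> out_ch (the up path), which is why LoRMCon2d uses a
# 1x1 kernel on the up.
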
def extract_linear(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu',
) -> Tuple[Tensor, Tensor, int, Tensor]:
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    U, S, Vh = torch.linalg.svd(weight)

    if mode == 'percentage':
        assert 0 <= mode_param <= 1  # ensure it is a valid percentage
        desired_params = mode_param * out_ch * in_ch
        # factored params = lora_rank * (in_ch + out_ch); solve for lora_rank
        lora_rank = int(desired_params / (in_ch + out_ch))
    elif mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile"/"percentile" or "percentage"')
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        # print("rank is higher than it should be")
        lora_rank = int(out_ch / 2)
        # return weight, 'full'

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, lora_rank, diff
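
# Quick numerical check of the contract above (a sketch, local names only):
#   w = torch.randn(64, 128)
#   A, B, r, diff = extract_linear(w, mode='ratio', mode_param=0.1)
#   torch.allclose(B @ A + diff, w, atol=1e-5)  # True: diff carries the residual
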
def replace_module_by_path(network, name, module):
    """Replace a module in a network by its dotted name."""
    name_parts = name.split('.')
    current_module = network
    for part in name_parts[:-1]:
        current_module = getattr(current_module, part)
    try:
        setattr(current_module, name_parts[-1], module)
    except Exception as e:
        print(e)


def count_parameters(module):
    return sum(p.numel() for p in module.parameters())


def compute_optimal_bias(original_module, linear_down, linear_up, X):
    Y_original = original_module(X)
    Y_approx = linear_up(linear_down(X))
    E = Y_original - Y_approx

    optimal_bias = E.mean(dim=0)

    return optimal_bias
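
# Why the mean is the right bias: choosing b to minimize the batch average of
# ||e - b||^2 gives d/db E[(e - b)^2] = -2 E[e - b] = 0, i.e. b = E[e], the
# per-feature mean of the residual E computed above.
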
@torch.no_grad()
def convert_diffusers_unet_to_lorm(
        unet: UNet2DConditionModel,
        extract_mode: ExtractMode = "percentile",
        mode_param: Union[int, float] = 0.5,
        parameter_threshold: int = 500000,
        # parameter_threshold: int = 1500000
):
    print('Converting UNet to LoRM UNet')
    start_num_params = count_parameters(unet)
    named_modules = list(unet.named_modules())

    num_replaced = 0

    pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
    layer_names_replaced = []
    converted_modules = []
    ignore_if_contains = [
        'proj_out', 'proj_in',
    ]

    def format_with_commas(n):
        return f"{n:,}"

    for name, module in named_modules:
        module_name = module.__class__.__name__
        if module_name in UNET_TARGET_REPLACE_MODULE:
            for child_name, child_module in module.named_modules():
                new_module: Union[LoRMCon2d, LoRMLinear, None] = None
                # if child name includes attn, skip it
                combined_name = f"{name}.{child_name}"
                # if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
                #     pass
                if any([word in child_name for word in ignore_if_contains]):
                    pass

                elif child_module.__class__.__name__ in LINEAR_MODULES:
                    if count_parameters(child_module) > parameter_threshold:

                        dtype = child_module.weight.dtype
                        # extract and convert
                        down_weight, up_weight, lora_dim, diff = extract_linear(
                            weight=child_module.weight.clone().detach().float(),
                            mode=extract_mode,
                            mode_param=mode_param,
                            device=child_module.weight.device,
                        )
                        down_weight = down_weight.to(dtype=dtype)
                        up_weight = up_weight.to(dtype=dtype)
                        bias_weight = None
                        if child_module.bias is not None:
                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
                        # linear layer weights = (out_features, in_features)
                        new_module = LoRMLinear(
                            in_features=down_weight.shape[1],
                            lorm_features=lora_dim,
                            out_features=up_weight.shape[0],
                            bias=bias_weight is not None,
                            device=down_weight.device,
                            dtype=down_weight.dtype
                        )

                        # replace the weights
                        new_module.down.weight.data = down_weight
                        new_module.up.weight.data = up_weight
                        if bias_weight is not None:
                            new_module.up.bias.data = bias_weight
                        # else:
                        #     new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data)

                        # bias_correction = compute_optimal_bias(
                        #     child_module,
                        #     new_module.down,
                        #     new_module.up,
                        #     torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype)
                        # )
                        # new_module.up.bias.data += bias_correction

                elif child_module.__class__.__name__ in CONV_MODULES:
                    if count_parameters(child_module) > parameter_threshold:
                        dtype = child_module.weight.dtype
                        down_weight, up_weight, lora_dim, diff = extract_conv(
                            weight=child_module.weight.clone().detach().float(),
                            mode=extract_mode,
                            mode_param=mode_param,
                            device=child_module.weight.device,
                        )
                        down_weight = down_weight.to(dtype=dtype)
                        up_weight = up_weight.to(dtype=dtype)
                        bias_weight = None
                        if child_module.bias is not None:
                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)

                        new_module = LoRMCon2d(
                            in_channels=down_weight.shape[1],
                            lorm_channels=lora_dim,
                            out_channels=up_weight.shape[0],
                            kernel_size=child_module.kernel_size,
                            dilation=child_module.dilation,
                            padding=child_module.padding,
                            padding_mode=child_module.padding_mode,
                            stride=child_module.stride,
                            bias=bias_weight is not None,
                            device=down_weight.device,
                            dtype=down_weight.dtype
                        )
                        # replace the weights
                        new_module.down.weight.data = down_weight
                        new_module.up.weight.data = up_weight
                        if bias_weight is not None:
                            new_module.up.bias.data = bias_weight

                if new_module:
                    replace_module_by_path(unet, combined_name, new_module)
                    converted_modules.append(new_module)
                    num_replaced += 1
                    layer_names_replaced.append(f"{combined_name} - {format_with_commas(count_parameters(child_module))}")

        pbar.update(1)
    pbar.close()
    end_num_params = count_parameters(unet)

    start_formatted = format_with_commas(start_num_params)
    end_formatted = format_with_commas(end_num_params)
    num_replaced_formatted = format_with_commas(num_replaced)

    width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))

    def sorting_key(s):
        # extract the number part, remove commas, and convert to integer
        return int(s.split("-")[1].strip().replace(",", ""))

    sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)

    for layer_name in sorted_layer_names_replaced:
        print(layer_name)

    print("Convert UNet result:")
    print(f" - converted: {num_replaced:>{width},} modules")
    print(f" - start: {start_num_params:>{width},} params")
    print(f" - end: {end_num_params:>{width},} params")

    return converted_modules
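
For orientation, a minimal end-to-end sketch of using this file on its own (the checkpoint id and dtype are placeholders, not part of the commit):

    import torch
    from diffusers import UNet2DConditionModel
    from toolkit.lorm import convert_diffusers_unet_to_lorm

    # hypothetical checkpoint; any diffusers SD-style UNet should work here
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float32
    )
    lorm_modules = convert_diffusers_unet_to_lorm(unet, extract_mode='ratio', mode_param=0.27)
    print(len(lorm_modules), "modules replaced with down/up low-rank pairs")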

@@ -456,7 +456,7 @@ class StableDiffusion:
                **extra
            ).images[0]
 
-            gen_config.save_image(img)
+            gen_config.save_image(img, i)
 
        # clear pipeline and cache to reduce vram usage
        del pipeline