from typing import Union, Tuple, Literal, Optional

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from torch import Tensor
from tqdm import tqdm

from toolkit.config_modules import LoRMConfig

conv = nn.Conv2d
lin = nn.Linear
_size_2_t = Union[int, Tuple[int, int]]

# Literal (not Union) is the correct type for a fixed set of accepted string values.
ExtractMode = Literal[
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentage'
]

LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear'
]
CONV_MODULES = [
    # 'Conv2d',
    # 'LoRACompatibleConv'
]

UNET_TARGET_REPLACE_MODULE = [
    "Transformer2DModel",
    # "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
]

LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE

UNET_TARGET_REPLACE_NAME = [
    "conv_in",
    "conv_out",
    "time_embedding.linear_1",
    "time_embedding.linear_2",
]

UNET_MODULES_TO_AVOID = []


# Low Rank Convolution: a full Conv2d factored into a k x k "down" conv that
# projects to `lorm_channels`, followed by a 1x1 "up" conv back to `out_channels`.
class LoRMCon2d(nn.Module):
    def __init__(
            self,
            in_channels: int,
            lorm_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 'same',
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.lorm_channels = lorm_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        self.down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=lorm_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype
        )

        # The kernel size on the up projection is always 1x1: the SVD extraction
        # folds the spatial kernel into the down projection only, so there is no
        # second spatial kernel to recover.
        self.up = nn.Conv2d(
            in_channels=lorm_channels,
            out_channels=out_channels,
            kernel_size=(1, 1),
            stride=1,
            padding='same',
            dilation=1,
            groups=1,
            bias=bias,
            padding_mode='zeros',
            device=device,
            dtype=dtype
        )

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        x = input
        x = self.down(x)
        x = self.up(x)
        return x
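

# Illustrative sketch (not part of the original module): a rank-r LoRMCon2d
# replaces a full Conv2d's out*in*k*k weight with r*in*k*k (down) plus out*r
# (up) parameters. The sizes below are arbitrary assumptions; the helper just
# sanity-checks output shapes and parameter counts.
def _demo_lorm_conv2d_shapes():
    full = nn.Conv2d(320, 640, kernel_size=3, padding='same')
    lorm = LoRMCon2d(in_channels=320, lorm_channels=64, out_channels=640, kernel_size=(3, 3))
    x = torch.randn(1, 320, 32, 32)
    assert full(x).shape == lorm(x).shape  # same output shape, fewer weights
    full_params = sum(p.numel() for p in full.parameters())
    lorm_params = sum(p.numel() for p in lorm.parameters())
    print(f"full: {full_params:,} params, lorm: {lorm_params:,} params")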


class LoRMLinear(nn.Module):
    """A full Linear factored into a bias-free down projection to
    `lorm_features` followed by an up projection to `out_features`."""

    def __init__(
            self,
            in_features: int,
            lorm_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.lorm_features = lorm_features
        self.out_features = out_features

        self.down = nn.Linear(
            in_features=in_features,
            out_features=lorm_features,
            bias=False,
            device=device,
            dtype=dtype
        )
        self.up = nn.Linear(
            in_features=lorm_features,
            out_features=out_features,
            bias=bias,
            # bias=True,
            device=device,
            dtype=dtype
        )

    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
        x = input
        x = self.down(x)
        x = self.up(x)
        return x
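

# Illustrative sketch (assumed sizes): a rank-r LoRMLinear stores r*in + out*r
# weights instead of out*in, at the cost of approximation error.
def _demo_lorm_linear_shapes():
    full = nn.Linear(1280, 1280)  # 1,638,400 weights (+ bias)
    lorm = LoRMLinear(in_features=1280, lorm_features=128, out_features=1280)  # 327,680 (+ bias)
    x = torch.randn(4, 1280)
    assert full(x).shape == lorm(x).shape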


def extract_conv(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu'
) -> Tuple[Tensor, Tensor, int, Tensor]:
    weight = weight.to(device)
    out_ch, in_ch, kernel_size, _ = weight.shape

    U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1))
    if mode == 'percentage':
        assert 0 <= mode_param <= 1  # must be a valid fraction of the original parameters
        original_params = out_ch * in_ch * kernel_size * kernel_size
        desired_params = mode_param * original_params
        # Solve rank * (in_ch * k * k + out_ch) = desired_params for the rank
        lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch))
    elif mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile' or mode == 'percentile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"'
        )
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        # heuristic cap: keep the rank at no more than half the output channels
        lora_rank = int(out_ch / 2)
        print("determined rank is higher than it should be, capping at out_ch / 2")
        # print(f"Skipping layer as determined rank is too high")
        # return None, None, None, None
        # return weight, 'full'

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, lora_rank, diff
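

# Illustrative sketch: run the SVD extraction on a random conv weight and
# confirm that `diff` is exactly the residual of the rank-`lora_rank`
# reconstruction. The shapes here are arbitrary assumptions.
def _demo_extract_conv():
    w = torch.randn(64, 32, 3, 3)
    A, B, rank, diff = extract_conv(w, mode='fixed', mode_param=16)
    # B (out_ch x rank) @ A (rank x in_ch*k*k) approximates the flattened weight
    recon = (B.reshape(64, rank) @ A.reshape(rank, -1)).reshape(w.shape)
    assert torch.allclose(w - recon, diff, atol=1e-5)
    print(f"rank={rank}, residual norm={diff.norm().item():.4f}")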


def extract_linear(
        weight: Union[torch.Tensor, nn.Parameter],
        mode='fixed',
        mode_param=0,
        device='cpu',
) -> Tuple[Tensor, Tensor, int, Tensor]:
    weight = weight.to(device)
    out_ch, in_ch = weight.shape

    U, S, Vh = torch.linalg.svd(weight)

    if mode == 'percentage':
        assert 0 <= mode_param <= 1  # must be a valid fraction of the original parameters
        desired_params = mode_param * out_ch * in_ch
        # Solve rank * (in_ch + out_ch) = desired_params for the rank
        lora_rank = int(desired_params / (in_ch + out_ch))
    elif mode == 'fixed':
        lora_rank = mode_param
    elif mode == 'threshold':
        assert mode_param >= 0
        lora_rank = torch.sum(S > mode_param).item()
    elif mode == 'ratio':
        assert 1 >= mode_param >= 0
        min_s = torch.max(S) * mode_param
        lora_rank = torch.sum(S > min_s).item()
    elif mode == 'quantile':
        assert 1 >= mode_param >= 0
        s_cum = torch.cumsum(S, dim=0)
        min_cum_sum = mode_param * torch.sum(S)
        lora_rank = torch.sum(s_cum < min_cum_sum).item()
    else:
        raise NotImplementedError(
            'Extract mode should be "fixed", "threshold", "ratio", "quantile" or "percentage"'
        )
    lora_rank = max(1, lora_rank)
    lora_rank = min(out_ch, in_ch, lora_rank)
    if lora_rank >= out_ch / 2:
        # heuristic cap: keep the rank at no more than half the output channels
        # print(f"rank is higher than it should be")
        lora_rank = int(out_ch / 2)
        # return weight, 'full'
        # print(f"Skipping layer as determined rank is too high")
        # return None, None, None, None

    U = U[:, :lora_rank]
    S = S[:lora_rank]
    U = U @ torch.diag(S)
    Vh = Vh[:lora_rank, :]

    diff = (weight - U @ Vh).detach()
    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
    del U, S, Vh, weight
    return extract_weight_A, extract_weight_B, lora_rank, diff
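

# Illustrative sketch: same residual check for the linear extraction, with
# assumed sizes. A has shape (rank, in_features), B has shape (out_features, rank).
def _demo_extract_linear():
    w = torch.randn(256, 512)
    A, B, rank, diff = extract_linear(w, mode='fixed', mode_param=32)
    assert A.shape == (32, 512) and B.shape == (256, 32)
    assert torch.allclose(w - B @ A, diff, atol=1e-5)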


def replace_module_by_path(network, name, module):
    """Replace a module in a network by its dotted name."""
    name_parts = name.split('.')
    current_module = network
    for part in name_parts[:-1]:
        current_module = getattr(current_module, part)
    try:
        setattr(current_module, name_parts[-1], module)
    except Exception as e:
        print(e)
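

# Illustrative sketch: the dotted-path replacement walks children with getattr,
# so it works on any nn.Module tree, e.g. a small Sequential.
def _demo_replace_module_by_path():
    net = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    replace_module_by_path(net, '0', nn.Identity())
    assert isinstance(net[0], nn.Identity)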


def count_parameters(module):
    return sum(p.numel() for p in module.parameters())


def compute_optimal_bias(original_module, linear_down, linear_up, X):
    """Mean residual between the original module and its low-rank approximation
    over a probe batch X. Adding it to the up bias is the constant offset that
    minimizes the mean squared error of the approximation."""
    Y_original = original_module(X)
    Y_approx = linear_up(linear_down(X))
    E = Y_original - Y_approx

    optimal_bias = E.mean(dim=0)

    return optimal_bias
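

# Illustrative sketch: wire extracted factors into a LoRMLinear and fold the
# mean truncation error over a probe batch into the up-projection bias. The
# random probe X is a stand-in; real activations would give a better estimate.
def _demo_compute_optimal_bias():
    full = nn.Linear(64, 64)
    A, B, rank, _ = extract_linear(full.weight.detach().clone(), mode='fixed', mode_param=8)
    lorm = LoRMLinear(64, rank, 64)
    lorm.down.weight.data = A
    lorm.up.weight.data = B
    lorm.up.bias.data = full.bias.detach().clone()
    X = torch.randn(1000, 64)
    with torch.no_grad():
        lorm.up.bias += compute_optimal_bias(full, lorm.down, lorm.up, X)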


def format_with_commas(n):
    return f"{n:,}"


def print_lorm_extract_details(
        start_num_params: int,
        end_num_params: int,
        num_replaced: int,
):
    start_formatted = format_with_commas(start_num_params)
    end_formatted = format_with_commas(end_num_params)
    num_replaced_formatted = format_with_commas(num_replaced)

    width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))

    print("Convert UNet result:")
    print(f" - converted: {num_replaced:>{width},} modules")
    print(f" - start: {start_num_params:>{width},} params")
    print(f" - end: {end_num_params:>{width},} params")


lorm_ignore_if_contains = [
    'proj_out', 'proj_in',
]

lorm_parameter_threshold = 1_000_000


@torch.no_grad()
def convert_diffusers_unet_to_lorm(
        unet: UNet2DConditionModel,
        config: LoRMConfig,
):
    """Replace large linear (and, if enabled, conv) layers inside the target
    UNet blocks with low-rank LoRM modules extracted via SVD, and return the
    list of newly created modules."""
    print('Converting UNet to LoRM UNet')
    start_num_params = count_parameters(unet)
    named_modules = list(unet.named_modules())

    num_replaced = 0

    pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
    layer_names_replaced = []
    converted_modules = []
    ignore_if_contains = [
        'proj_out', 'proj_in',
    ]

    for name, module in named_modules:
        module_name = module.__class__.__name__
        if module_name in UNET_TARGET_REPLACE_MODULE:
            for child_name, child_module in module.named_modules():
                new_module: Union[LoRMCon2d, LoRMLinear, None] = None
                combined_name = f"{name}.{child_name}"
                # if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
                #     pass

                lorm_config = config.get_config_for_module(combined_name)

                extract_mode = lorm_config.extract_mode
                extract_mode_param = lorm_config.extract_mode_param
                parameter_threshold = lorm_config.parameter_threshold

                # skip projection layers matched by the ignore list
                if any([word in child_name for word in ignore_if_contains]):
                    pass

                elif child_module.__class__.__name__ in LINEAR_MODULES:
                    if count_parameters(child_module) > parameter_threshold:

                        # dtype = child_module.weight.dtype
                        dtype = torch.float32
                        # extract and convert
                        down_weight, up_weight, lora_dim, diff = extract_linear(
                            weight=child_module.weight.clone().detach().float(),
                            mode=extract_mode,
                            mode_param=extract_mode_param,
                            device=child_module.weight.device,
                        )
                        if down_weight is None:
                            continue
                        down_weight = down_weight.to(dtype=dtype)
                        up_weight = up_weight.to(dtype=dtype)
                        bias_weight = None
                        if child_module.bias is not None:
                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
                        # linear layer weights = (out_features, in_features)
                        new_module = LoRMLinear(
                            in_features=down_weight.shape[1],
                            lorm_features=lora_dim,
                            out_features=up_weight.shape[0],
                            bias=bias_weight is not None,
                            device=down_weight.device,
                            dtype=down_weight.dtype
                        )

                        # replace the weights
                        new_module.down.weight.data = down_weight
                        new_module.up.weight.data = up_weight
                        if bias_weight is not None:
                            new_module.up.bias.data = bias_weight
                        # else:
                        #     new_module.up.bias.data = torch.zeros_like(new_module.up.bias.data)

                        # bias_correction = compute_optimal_bias(
                        #     child_module,
                        #     new_module.down,
                        #     new_module.up,
                        #     torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype)
                        # )
                        # new_module.up.bias.data += bias_correction

                elif child_module.__class__.__name__ in CONV_MODULES:
                    if count_parameters(child_module) > parameter_threshold:
                        dtype = child_module.weight.dtype
                        down_weight, up_weight, lora_dim, diff = extract_conv(
                            weight=child_module.weight.clone().detach().float(),
                            mode=extract_mode,
                            mode_param=extract_mode_param,
                            device=child_module.weight.device,
                        )
                        if down_weight is None:
                            continue
                        down_weight = down_weight.to(dtype=dtype)
                        up_weight = up_weight.to(dtype=dtype)
                        bias_weight = None
                        if child_module.bias is not None:
                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)

                        new_module = LoRMCon2d(
                            in_channels=down_weight.shape[1],
                            lorm_channels=lora_dim,
                            out_channels=up_weight.shape[0],
                            kernel_size=child_module.kernel_size,
                            dilation=child_module.dilation,
                            padding=child_module.padding,
                            padding_mode=child_module.padding_mode,
                            stride=child_module.stride,
                            bias=bias_weight is not None,
                            device=down_weight.device,
                            dtype=down_weight.dtype
                        )
                        # replace the weights
                        new_module.down.weight.data = down_weight
                        new_module.up.weight.data = up_weight
                        if bias_weight is not None:
                            new_module.up.bias.data = bias_weight

                if new_module:
                    replace_module_by_path(unet, combined_name, new_module)
                    converted_modules.append(new_module)
                    num_replaced += 1
                    layer_names_replaced.append(
                        f"{combined_name} - {format_with_commas(count_parameters(child_module))}")

        pbar.update(1)
    pbar.close()
    end_num_params = count_parameters(unet)

    def sorting_key(s):
        # Extract the parameter count after the dash, drop commas, and convert to int
        return int(s.split("-")[1].strip().replace(",", ""))

    sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
    for layer_name in sorted_layer_names_replaced:
        print(layer_name)

    print_lorm_extract_details(
        start_num_params=start_num_params,
        end_num_params=end_num_params,
        num_replaced=num_replaced,
    )

    return converted_modules
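

# Illustrative usage (a sketch, not from this file). The exact LoRMConfig
# fields come from toolkit.config_modules, so treat `lorm_config` below as an
# assumption:
#
#   unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
#   converted = convert_diffusers_unet_to_lorm(unet, config=lorm_config)
#   # `converted` holds the new LoRM modules, e.g. for building optimizer param groups.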