diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 70ac192..f2a956f 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -18,6 +18,7 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatc
 from toolkit.embedding import Embedding
 from toolkit.ip_adapter import IPAdapter
 from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.lorm import convert_diffusers_unet_to_lorm
 from toolkit.lycoris_special import LycorisSpecialNetwork
 from toolkit.network_mixins import Network
 from toolkit.optimizer import get_optimizer
@@ -126,6 +127,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         is_training_adapter = self.adapter_config is not None and self.adapter_config.train
 
+        self.do_lorm = self.get_conf('do_lorm', False)
+
         # get the device state preset based on what we are training
         self.train_device_state_preset = get_train_sd_device_state_preset(
             device=self.device_torch,
@@ -675,6 +678,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # torch.autograd.set_detect_anomaly(True)
         # run base process run
         BaseTrainProcess.run(self)
+        params = []
 
         ### HOOK ###
         self.hook_before_model_load()
@@ -708,6 +712,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # run base sd process run
         self.sd.load_model()
 
+        if self.do_lorm:
+            train_modules = convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27)
+            for module in train_modules:
+                p = module.parameters()
+                for param in p:
+                    param.requires_grad_(True)
+                    params.append(param)
+
+
         dtype = get_torch_dtype(self.train_config.dtype)
 
         # model is loaded from BaseSDProcess
@@ -767,7 +780,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.datasets_reg is not None:
             self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.sd)
 
-        params = []
         if not self.is_fine_tuning:
             if self.network_config is not None:
                 # TODO should we completely switch to LycorisSpecialNetwork?
@@ -903,8 +915,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         # set the device state preset before getting params
         self.sd.set_device_state(self.train_device_state_preset)
-        params = self.get_params()
-        if not params:
+
+        # params = self.get_params()
+        if len(params) == 0:
             # will only return savable weights and ones with grad
             params = self.sd.prepare_optimizer_params(
                 unet=self.train_config.train_unet,
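Reviewer note, not part of the patch: the new toolkit/lorm.py (below) replaces heavy UNet layers with an SVD-derived low-rank pair. A minimal standalone sketch of the factorization extract_linear performs, with illustrative shapes and an arbitrary rank:

import torch

# Approximate a weight matrix W (out x in) by a rank-r pair, folding the
# singular values into the up factor: W ~= (U_r @ diag(S_r)) @ Vh_r.
W = torch.randn(320, 1280)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
r = 64                               # stand-in for a rank one of the modes picks
up = U[:, :r] @ torch.diag(S[:r])    # (320, 64), singular values folded in
down = Vh[:r, :]                     # (64, 1280)
rel_err = (W - up @ down).norm() / W.norm()
print(f"rank {r}: relative error {rel_err:.3f}, "
      f"params {up.numel() + down.numel():,} vs {W.numel():,}")
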
diff --git a/toolkit/lorm.py b/toolkit/lorm.py
new file mode 100644
index 0000000..b6acb6d
--- /dev/null
+++ b/toolkit/lorm.py
@@ -0,0 +1,424 @@
+from typing import Union, Tuple, Literal, Optional
+
+import torch
+import torch.nn as nn
+from diffusers import UNet2DConditionModel
+from torch import Tensor
+from tqdm import tqdm
+
+conv = nn.Conv2d
+lin = nn.Linear
+_size_2_t = Union[int, Tuple[int, int]]
+
+ExtractMode = Literal[
+    'fixed',
+    'threshold',
+    'ratio',
+    'quantile',
+    'percentile',
+    'percentage'
+]
+
+LINEAR_MODULES = [
+    'Linear',
+    'LoRACompatibleLinear'
+]
+CONV_MODULES = [
+    # 'Conv2d',
+    # 'LoRACompatibleConv'
+]
+
+UNET_TARGET_REPLACE_MODULE = [
+    "Transformer2DModel",
+    # "BasicTransformerBlock",
+    # "ResnetBlock2D",
+    "Downsample2D",
+    "Upsample2D",
+]
+
+UNET_TARGET_REPLACE_NAME = [
+    "conv_in",
+    "conv_out",
+    "time_embedding.linear_1",
+    "time_embedding.linear_2",
+]
+
+UNET_MODULES_TO_AVOID = [
+]
+
+
+# Low Rank Convolution: a k x k "down" conv into fewer channels followed by
+# a 1x1 "up" conv back to the original output channels.
+class LoRMCon2d(nn.Module):
+    def __init__(
+            self,
+            in_channels: int,
+            lorm_channels: int,
+            out_channels: int,
+            kernel_size: _size_2_t,
+            stride: _size_2_t = 1,
+            padding: Union[str, _size_2_t] = 'same',
+            dilation: _size_2_t = 1,
+            groups: int = 1,
+            bias: bool = True,
+            padding_mode: str = 'zeros',
+            device=None,
+            dtype=None
+    ) -> None:
+        super().__init__()
+        self.in_channels = in_channels
+        self.lorm_channels = lorm_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = groups
+        self.padding_mode = padding_mode
+
+        self.down = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=lorm_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=False,
+            padding_mode=padding_mode,
+            device=device,
+            dtype=dtype
+        )
+
+        # The up projection always uses a 1x1 kernel: the SVD extraction puts
+        # the full k x k kernel on the down side, and a dual k x k split does
+        # not fall out of the factorization.
+        self.up = nn.Conv2d(
+            in_channels=lorm_channels,
+            out_channels=out_channels,
+            kernel_size=(1, 1),
+            stride=1,
+            padding='same',
+            dilation=1,
+            groups=1,
+            bias=bias,
+            padding_mode='zeros',
+            device=device,
+            dtype=dtype
+        )
+
+    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
+        x = input
+        x = self.down(x)
+        x = self.up(x)
+        return x
+
+
+class LoRMLinear(nn.Module):
+    def __init__(
+            self,
+            in_features: int,
+            lorm_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None
+    ) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.lorm_features = lorm_features
+        self.out_features = out_features
+
+        self.down = nn.Linear(
+            in_features=in_features,
+            out_features=lorm_features,
+            bias=False,
+            device=device,
+            dtype=dtype
+        )
+        self.up = nn.Linear(
+            in_features=lorm_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype
+        )
+
+    def forward(self, input: Tensor, *args, **kwargs) -> Tensor:
+        x = input
+        x = self.down(x)
+        x = self.up(x)
+        return x
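+
+
+# Both extractors below take a truncated SVD of the original weight and fold
+# the singular values into the up factor, so W ~= (U @ diag(S)) @ Vh. The
+# per-layer rank comes from `mode`: a fixed value ('fixed'), a threshold on
+# the singular values ('threshold'), a fraction of the largest singular value
+# ('ratio'), a cumulative-sum quantile ('quantile'/'percentile'), or a target
+# parameter budget ('percentage').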
+def extract_conv(
+        weight: Union[torch.Tensor, nn.Parameter],
+        mode='fixed',
+        mode_param=0,
+        device='cpu'
+) -> Tuple[Tensor, Tensor, int, Tensor]:
+    weight = weight.to(device)
+    out_ch, in_ch, kernel_size, _ = weight.shape
+
+    U, S, Vh = torch.linalg.svd(weight.reshape(out_ch, -1))
+    if mode == 'percentage':
+        assert 0 <= mode_param <= 1  # Ensure it's a valid percentage.
+        original_params = out_ch * in_ch * kernel_size * kernel_size
+        desired_params = mode_param * original_params
+        # A rank-r pair stores r * (in_ch * k * k + out_ch) values, so solve
+        # desired_params = r * (in_ch * k * k + out_ch) for the rank.
+        lora_rank = int(desired_params / (in_ch * kernel_size * kernel_size + out_ch))
+    elif mode == 'fixed':
+        lora_rank = mode_param
+    elif mode == 'threshold':
+        assert mode_param >= 0
+        lora_rank = torch.sum(S > mode_param).item()
+    elif mode == 'ratio':
+        assert 1 >= mode_param >= 0
+        min_s = torch.max(S) * mode_param
+        lora_rank = torch.sum(S > min_s).item()
+    elif mode == 'quantile' or mode == 'percentile':
+        assert 1 >= mode_param >= 0
+        s_cum = torch.cumsum(S, dim=0)
+        min_cum_sum = mode_param * torch.sum(S)
+        lora_rank = torch.sum(s_cum < min_cum_sum).item()
+    else:
+        raise NotImplementedError(
+            'Extract mode should be "fixed", "threshold", "ratio", "quantile"/"percentile" or "percentage"'
+        )
+    lora_rank = max(1, lora_rank)
+    lora_rank = min(out_ch, in_ch, lora_rank)
+    if lora_rank >= out_ch / 2:
+        # Cap the rank so the factorized pair stays meaningfully smaller
+        # than the original layer.
+        lora_rank = int(out_ch / 2)
+        print("rank is higher than it should be")
+        # return weight, 'full'
+
+    U = U[:, :lora_rank]
+    S = S[:lora_rank]
+    U = U @ torch.diag(S)
+    Vh = Vh[:lora_rank, :]
+
+    diff = (weight - (U @ Vh).reshape(out_ch, in_ch, kernel_size, kernel_size)).detach()
+    extract_weight_A = Vh.reshape(lora_rank, in_ch, kernel_size, kernel_size).detach()
+    extract_weight_B = U.reshape(out_ch, lora_rank, 1, 1).detach()
+    del U, S, Vh, weight
+    return extract_weight_A, extract_weight_B, lora_rank, diff
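+
+
+# extract_linear is the 2-D analogue of extract_conv: the same rank-selection
+# logic applied directly to an (out_features, in_features) weight matrix.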
+def extract_linear(
+        weight: Union[torch.Tensor, nn.Parameter],
+        mode='fixed',
+        mode_param=0,
+        device='cpu',
+) -> Tuple[Tensor, Tensor, int, Tensor]:
+    weight = weight.to(device)
+    out_ch, in_ch = weight.shape
+
+    U, S, Vh = torch.linalg.svd(weight)
+
+    if mode == 'percentage':
+        assert 0 <= mode_param <= 1  # Ensure it's a valid percentage.
+        desired_params = mode_param * out_ch * in_ch
+        # A rank-r pair stores r * (in_ch + out_ch) values, so solve
+        # desired_params = r * (in_ch + out_ch) for the rank.
+        lora_rank = int(desired_params / (in_ch + out_ch))
+    elif mode == 'fixed':
+        lora_rank = mode_param
+    elif mode == 'threshold':
+        assert mode_param >= 0
+        lora_rank = torch.sum(S > mode_param).item()
+    elif mode == 'ratio':
+        assert 1 >= mode_param >= 0
+        min_s = torch.max(S) * mode_param
+        lora_rank = torch.sum(S > min_s).item()
+    elif mode == 'quantile' or mode == 'percentile':
+        assert 1 >= mode_param >= 0
+        s_cum = torch.cumsum(S, dim=0)
+        min_cum_sum = mode_param * torch.sum(S)
+        lora_rank = torch.sum(s_cum < min_cum_sum).item()
+    else:
+        raise NotImplementedError(
+            'Extract mode should be "fixed", "threshold", "ratio", "quantile"/"percentile" or "percentage"'
+        )
+    lora_rank = max(1, lora_rank)
+    lora_rank = min(out_ch, in_ch, lora_rank)
+    if lora_rank >= out_ch / 2:
+        # print("rank is higher than it should be")
+        lora_rank = int(out_ch / 2)
+        # return weight, 'full'
+
+    U = U[:, :lora_rank]
+    S = S[:lora_rank]
+    U = U @ torch.diag(S)
+    Vh = Vh[:lora_rank, :]
+
+    diff = (weight - U @ Vh).detach()
+    extract_weight_A = Vh.reshape(lora_rank, in_ch).detach()
+    extract_weight_B = U.reshape(out_ch, lora_rank).detach()
+    del U, S, Vh, weight
+    return extract_weight_A, extract_weight_B, lora_rank, diff
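+
+
+# The helpers below swap a factorized module into the UNet by its dotted
+# path (e.g. "time_embedding.linear_1") and measure parameter counts before
+# and after conversion.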
+def replace_module_by_path(network, name, module):
+    """Replace a module in a network by its dotted name."""
+    name_parts = name.split('.')
+    current_module = network
+    for part in name_parts[:-1]:
+        current_module = getattr(current_module, part)
+    try:
+        setattr(current_module, name_parts[-1], module)
+    except Exception as e:
+        print(e)
+
+
+def count_parameters(module):
+    return sum(p.numel() for p in module.parameters())
+
+
+def compute_optimal_bias(original_module, linear_down, linear_up, X):
+    # The mean residual is the least-squares-optimal constant correction for
+    # the approximation error over the sample batch X.
+    Y_original = original_module(X)
+    Y_approx = linear_up(linear_down(X))
+    E = Y_original - Y_approx
+
+    optimal_bias = E.mean(dim=0)
+
+    return optimal_bias
+
+
+@torch.no_grad()
+def convert_diffusers_unet_to_lorm(
+        unet: UNet2DConditionModel,
+        extract_mode: ExtractMode = "percentile",
+        mode_param: Union[int, float] = 0.5,
+        parameter_threshold: int = 500000,
+        # parameter_threshold: int = 1500000
+):
+    print('Converting UNet to LoRM UNet')
+    start_num_params = count_parameters(unet)
+    named_modules = list(unet.named_modules())
+
+    num_replaced = 0
+
+    pbar = tqdm(total=len(named_modules), desc="UNet -> LoRM UNet")
+    layer_names_replaced = []
+    converted_modules = []
+    ignore_if_contains = [
+        'proj_out', 'proj_in',
+    ]
+
+    def format_with_commas(n):
+        return f"{n:,}"
+
+    for name, module in named_modules:
+        module_name = module.__class__.__name__
+        if module_name in UNET_TARGET_REPLACE_MODULE:
+            for child_name, child_module in module.named_modules():
+                new_module: Union[LoRMCon2d, LoRMLinear, None] = None
+                combined_name = f"{name}.{child_name}"
+                # skip attention projections and anything else on the ignore list
+                if any([word in child_name for word in ignore_if_contains]):
+                    pass
+
+                elif child_module.__class__.__name__ in LINEAR_MODULES:
+                    if count_parameters(child_module) > parameter_threshold:
+                        dtype = child_module.weight.dtype
+                        # extract and convert
+                        down_weight, up_weight, lora_dim, diff = extract_linear(
+                            weight=child_module.weight.clone().detach().float(),
+                            mode=extract_mode,
+                            mode_param=mode_param,
+                            device=child_module.weight.device,
+                        )
+                        down_weight = down_weight.to(dtype=dtype)
+                        up_weight = up_weight.to(dtype=dtype)
+                        bias_weight = None
+                        if child_module.bias is not None:
+                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
+                        # linear layer weights = (out_features, in_features)
+                        new_module = LoRMLinear(
+                            in_features=down_weight.shape[1],
+                            lorm_features=lora_dim,
+                            out_features=up_weight.shape[0],
+                            bias=bias_weight is not None,
+                            device=down_weight.device,
+                            dtype=down_weight.dtype
+                        )
+
+                        # replace the weights
+                        new_module.down.weight.data = down_weight
+                        new_module.up.weight.data = up_weight
+                        if bias_weight is not None:
+                            new_module.up.bias.data = bias_weight
+
+                        # bias_correction = compute_optimal_bias(
+                        #     child_module,
+                        #     new_module.down,
+                        #     new_module.up,
+                        #     torch.randn((1000, down_weight.shape[1])).to(device=down_weight.device, dtype=dtype)
+                        # )
+                        # new_module.up.bias.data += bias_correction
+
+                elif child_module.__class__.__name__ in CONV_MODULES:
+                    if count_parameters(child_module) > parameter_threshold:
+                        dtype = child_module.weight.dtype
+                        down_weight, up_weight, lora_dim, diff = extract_conv(
+                            weight=child_module.weight.clone().detach().float(),
+                            mode=extract_mode,
+                            mode_param=mode_param,
+                            device=child_module.weight.device,
+                        )
+                        down_weight = down_weight.to(dtype=dtype)
+                        up_weight = up_weight.to(dtype=dtype)
+                        bias_weight = None
+                        if child_module.bias is not None:
+                            bias_weight = child_module.bias.data.clone().detach().to(dtype=dtype)
+
+                        new_module = LoRMCon2d(
+                            in_channels=down_weight.shape[1],
+                            lorm_channels=lora_dim,
+                            out_channels=up_weight.shape[0],
+                            kernel_size=child_module.kernel_size,
+                            dilation=child_module.dilation,
+                            padding=child_module.padding,
+                            padding_mode=child_module.padding_mode,
+                            stride=child_module.stride,
+                            bias=bias_weight is not None,
+                            device=down_weight.device,
+                            dtype=down_weight.dtype
+                        )
+                        # replace the weights
+                        new_module.down.weight.data = down_weight
+                        new_module.up.weight.data = up_weight
+                        if bias_weight is not None:
+                            new_module.up.bias.data = bias_weight
+
+                if new_module is not None:
+                    replace_module_by_path(unet, combined_name, new_module)
+                    converted_modules.append(new_module)
+                    num_replaced += 1
+                    # report the original layer's size
+                    layer_names_replaced.append(f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
+
+        pbar.update(1)
+    pbar.close()
+    end_num_params = count_parameters(unet)
+
+    start_formatted = format_with_commas(start_num_params)
+    end_formatted = format_with_commas(end_num_params)
+    num_replaced_formatted = format_with_commas(num_replaced)
+
+    width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))
+
+    def sorting_key(s):
+        # Extract the number part, remove commas, and convert to integer
+        return int(s.split("-")[1].strip().replace(",", ""))
+
+    sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)
+
+    for layer_name in sorted_layer_names_replaced:
+        print(layer_name)
+
+    print("Convert UNet result:")
+    print(f" - converted: {num_replaced:>{width},} modules")
+    print(f" - start:     {start_num_params:>{width},} params")
+    print(f" - end:       {end_num_params:>{width},} params")
+
+    return converted_modules
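For reviewers, a condensed sketch of how the trainer consumes the converter (mirrors the BaseSDTrainProcess hunk above; the model path is a placeholder, and any diffusers UNet2DConditionModel works):

from diffusers import UNet2DConditionModel
from toolkit.lorm import convert_diffusers_unet_to_lorm

# placeholder path to a checkpoint in diffusers layout
unet = UNet2DConditionModel.from_pretrained("path/to/diffusers-model", subfolder="unet")

# 'ratio' keeps singular values above 27% of each layer's largest one
train_modules = convert_diffusers_unet_to_lorm(unet, 'ratio', 0.27)

params = []
for module in train_modules:
    for param in module.parameters():
        param.requires_grad_(True)   # only the factorized pairs train
        params.append(param)
# `params` then replaces the usual network params handed to the optimizer.
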
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index c9a19fa..ee71903 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -456,7 +456,7 @@ class StableDiffusion:
                 **extra
             ).images[0]
 
-            gen_config.save_image(img)
+            gen_config.save_image(img, i)
 
             # clear pipeline and cache to reduce vram usage
             del pipeline