Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 11:11:37 +00:00
Improved LoRM extraction and training
@@ -1,17 +1,23 @@
import json
import os
from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal

import torch
from torch import nn
import weakref

from tqdm import tqdm

from toolkit.config_modules import NetworkConfig
from toolkit.lorm import extract_conv, extract_linear, count_parameters
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT

if TYPE_CHECKING:
    from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
    from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
    from toolkit.stable_diffusion_model import StableDiffusion

Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule']
@@ -26,6 +32,15 @@ CONV_MODULES = [
    'LoRACompatibleConv'
]

ExtractMode = Literal[
    'existing',
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentage'
]

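
# For intuition, a minimal sketch of how these extract modes could map to a
# rank choice from a descending vector of singular values. This helper is
# illustrative only; the real selection logic lives in toolkit.lorm
# (extract_linear / extract_conv), whose internals are not shown in this diff.
def sketch_rank_for_mode(singular_values: torch.Tensor, mode: str, mode_param) -> int:
    if mode == 'fixed':
        # keep exactly mode_param singular values
        return int(mode_param)
    if mode == 'threshold':
        # keep values above an absolute threshold
        return int((singular_values > mode_param).sum().item())
    if mode == 'ratio':
        # keep values above mode_param times the largest value
        return int((singular_values > singular_values[0] * mode_param).sum().item())
    if mode == 'quantile':
        # keep the smallest prefix holding mode_param of the total energy
        cum = torch.cumsum(singular_values, dim=0) / singular_values.sum()
        return int((cum < mode_param).sum().item()) + 1
    # 'existing' is resolved in extract_weight below; 'percentage' is handled elsewhere
    raise ValueError(f"unhandled mode: {mode}")
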
def broadcast_and_multiply(tensor, multiplier):
    # Determine the number of dimensions required
@@ -41,20 +56,101 @@ def broadcast_and_multiply(tensor, multiplier):
    return result

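
# The body of broadcast_and_multiply is elided by the hunk above. From its
# use in forward() further down, it plausibly unsqueezes a per-sample
# multiplier to the tensor's rank before multiplying; a hedged sketch:
def sketch_broadcast_and_multiply(tensor: torch.Tensor, multiplier) -> torch.Tensor:
    if not isinstance(multiplier, torch.Tensor):
        return tensor * multiplier
    while multiplier.dim() < tensor.dim():
        multiplier = multiplier.unsqueeze(-1)  # e.g. [B] -> [B, 1, 1, 1]
    return tensor * multiplier
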
def add_bias(tensor, bias):
    if bias is None:
        return tensor
    # add batch dim
    bias = bias.unsqueeze(0)
    bias = torch.cat([bias] * tensor.size(0), dim=0)
    # Determine the number of dimensions required
    num_extra_dims = tensor.dim() - bias.dim()

    # Unsqueezing the tensor to match the dimensionality
    for _ in range(num_extra_dims):
        bias = bias.unsqueeze(-1)

    # we may need to swap -1 for -2
    if bias.size(1) != tensor.size(1):
        if len(bias.size()) == 3:
            bias = bias.permute(0, 2, 1)
        elif len(bias.size()) == 4:
            bias = bias.permute(0, 3, 1, 2)

    # Adding the broadcasted bias to the output tensor
    try:
        result = tensor + bias
    except RuntimeError as e:
        print(e)
        print(tensor.size())
        print(bias.size())
        raise e

    return result

class ExtractableModuleMixin:
    def extract_weight(
            self: Module,
            extract_mode: ExtractMode = "existing",
            extract_mode_param: Union[int, float] = None,
    ):
        device = self.lora_down.weight.device
        weight_to_extract = self.org_module[0].weight
        if extract_mode == "existing":
            extract_mode = 'fixed'
            extract_mode_param = self.lora_dim

        if self.org_module[0].__class__.__name__ in CONV_MODULES:
            # do conv extraction
            down_weight, up_weight, new_dim, diff = extract_conv(
                weight=weight_to_extract.clone().detach().float(),
                mode=extract_mode,
                mode_param=extract_mode_param,
                device=device
            )

        elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
            # do linear extraction
            down_weight, up_weight, new_dim, diff = extract_linear(
                weight=weight_to_extract.clone().detach().float(),
                mode=extract_mode,
                mode_param=extract_mode_param,
                device=device,
            )
        else:
            raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}")

        self.lora_dim = new_dim

        # inject the extracted weights into the params
        self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach()
        self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach()

        # copy the bias if we have one and are using them
        if self.org_module[0].bias is not None and self.lora_up.bias is not None:
            self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach()

        # set up the alphas
        self.alpha = (self.alpha * 0) + down_weight.shape[0]
        self.scale = self.alpha / self.lora_dim

        # handle the trainable scalar method LoCon uses
        if hasattr(self, 'scalar'):
            # scalar is a parameter; update its value to 1.0
            self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype)

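
# For intuition, a minimal sketch of what a linear extraction like
# toolkit.lorm.extract_linear could do under the hood: split a weight matrix
# into two low-rank factors with a truncated SVD. The name and return shape
# below are assumptions for illustration; the real helper is not shown here.
def sketch_extract_linear(weight: torch.Tensor, rank: int):
    # weight: [out_features, in_features]
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    down_weight = Vh[:rank, :]                       # [rank, in]
    up_weight = U[:, :rank] * S[:rank].unsqueeze(0)  # [out, rank], singular values absorbed
    diff = weight - up_weight @ down_weight          # residual dropped by the truncation
    return down_weight, up_weight, rank, diff
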
class ToolkitModuleMixin:
    def __init__(
            self: Module,
            *args,
            network: Network,
            call_super_init: bool = True,
            **kwargs
    ):
        if call_super_init:
            super().__init__(*args, **kwargs)
        self.network_ref: weakref.ref = weakref.ref(network)
        self.is_checkpointing = False
        # self.is_normalizing = False
        self.normalize_scaler = 1.0
        self._multiplier: Union[float, list, torch.Tensor] = None

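    # The weakref.ref above avoids a reference cycle: the network holds its
    # modules, so each module holds only a weak reference back. Calling the
    # ref returns the network, or None once it has been garbage collected.
    # A tiny standalone demonstration (CPython's refcounting collects the
    # object immediately here):
    #
    #   class _Owner: pass
    #   _owner = _Owner()
    #   _ref = weakref.ref(_owner)
    #   assert _ref() is _owner  # dereference while the target is alive
    #   del _owner
    #   assert _ref() is None    # the ref goes dead once the target is gone
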
    def _call_forward(self: Module, x):
@@ -100,11 +196,40 @@ class ToolkitModuleMixin:
        return lx * scale

    # this may get an additional positional arg or not
    def lorm_forward(self: Network, x, *args, **kwargs):
        network: Network = self.network_ref()
        if not network.is_active:
            return self.org_forward(x, *args, **kwargs)

        if network.lorm_train_mode == 'local':
            # we are going to predict the input with both and compute a loss between them
            inputs = x.detach()
            with torch.no_grad():
                # get the local prediction
                target_pred = self.org_forward(inputs, *args, **kwargs).detach()
            with torch.set_grad_enabled(True):
                # make a prediction with the LoRM
                lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True)))

                local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
                # backprop
                local_loss.backward()

            network.module_losses.append(local_loss.detach())
            # return the original as we don't want our training to affect modules down the line
            return target_pred

        else:
            return self.lora_up(self.lora_down(x))

    def forward(self: Module, x, *args, **kwargs):
        skip = False
        network = self.network_ref()
        network: Network = self.network_ref()
        if network.is_lorm:
            # we are doing LoRM
            return self.lorm_forward(x, *args, **kwargs)

        # skip if not active
        if not network.is_active:
            skip = True
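
# A standalone sketch of the 'local' LoRM training mode above, assuming a
# plain nn.Linear teacher: each low-rank pair is fit to reproduce its own
# original module's output, so no gradients flow between layers. All names
# and sizes here are illustrative, not from the toolkit.
teacher = nn.Linear(512, 512)
lora_down = nn.Linear(512, 32, bias=False)
lora_up = nn.Linear(32, 512)
optimizer = torch.optim.AdamW(
    list(lora_down.parameters()) + list(lora_up.parameters()), lr=1e-4)

x = torch.randn(4, 512)
with torch.no_grad():
    target_pred = teacher(x)         # frozen local target
lorm_pred = lora_up(lora_down(x))    # low-rank student prediction
loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
loss.backward()
optimizer.step()
optimizer.zero_grad()
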
@@ -130,40 +255,9 @@ class ToolkitModuleMixin:
        if lora_output_batch_size != multiplier_batch_size:
            num_interleaves = lora_output_batch_size // multiplier_batch_size
            multiplier = multiplier.repeat_interleave(num_interleaves)
        # multiplier = 1.0

        if self.network_ref().is_normalizing:
            with torch.no_grad():

                # do this calculation without the set multiplier; use the same polarity, but with a 1.0 multiplier
                if isinstance(multiplier, torch.Tensor):
                    norm_multiplier = multiplier.clone().detach() * 10
                    norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0)
                else:
                    norm_multiplier = multiplier

                # get a dim array from the orig forward that has the index of all dimensions except batch and channel

                # Calculate the target magnitude for the combined output
                orig_max = torch.max(torch.abs(org_forwarded))

                # Calculate the additional increase in magnitude that lora_output would introduce
                potential_max_increase = torch.max(
                    torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded))

                epsilon = 1e-6  # Small constant to avoid division by zero

                # Calculate the scaling factor for the lora_output
                # to ensure that the potential increase in magnitude doesn't change the original max
                normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
                normalize_scaler = normalize_scaler.detach()

                # save the scaler so it can be applied later
                self.normalize_scaler = normalize_scaler.clone().detach()

            lora_output = lora_output * normalize_scaler

        return org_forwarded + broadcast_and_multiply(lora_output, multiplier)
        x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
        return x

    def enable_gradient_checkpointing(self: Module):
        self.is_checkpointing = True
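
# A toy numeric check of the normalize_scaler formula above, using scalars
# for readability (values are made up): if the original output peaks at 4.0
# and the LoRA branch could raise that peak by 1.0, the branch is scaled so
# the combined peak stays near the original maximum.
toy_org = torch.tensor([4.0])
toy_lora = torch.tensor([1.0])
toy_orig_max = torch.max(torch.abs(toy_org))
toy_increase = torch.max(torch.abs(toy_org + toy_lora) - torch.abs(toy_org))
toy_scaler = toy_orig_max / (toy_orig_max + toy_increase + 1e-6)
print(toy_scaler)  # ~0.8
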
@@ -171,40 +265,6 @@ class ToolkitModuleMixin:
    def disable_gradient_checkpointing(self: Module):
        self.is_checkpointing = False

    @torch.no_grad()
    def apply_stored_normalizer(self: Module, target_normalize_scaler: float = 1.0):
        """
        Applies the previous normalization calculation to the module.
        This must be called before saving or the normalization will be lost.
        It is probably best to call it after each batch as well.
        We just scale the up/down weights to match this scaler.
        :return:
        """
        # get state dict
        state_dict = self.state_dict()
        dtype = state_dict['lora_up.weight'].dtype
        device = state_dict['lora_up.weight'].device

        # todo should we do this at fp32?
        if isinstance(self.normalize_scaler, torch.Tensor):
            scaler = self.normalize_scaler.clone().detach()
        else:
            scaler = torch.tensor(self.normalize_scaler).to(device, dtype=dtype)

        total_module_scale = scaler / target_normalize_scaler
        num_modules_layers = 2  # up and down
        up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
            .to(device, dtype=dtype)

        # apply the scaler to the up and down weights
        for key in state_dict.keys():
            if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
                # do it in place so the params are updated
                state_dict[key] *= up_down_scale

        # reset the normalization scaler
        self.normalize_scaler = target_normalize_scaler

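    # Why the square root in up_down_scale above: the branch output composes
    # the two weight matrices, so scaling both by sqrt(s) scales the composed
    # output by s. A quick check with made-up shapes:
    #
    #   _down = torch.randn(4, 16)   # [rank, in]
    #   _up = torch.randn(16, 4)     # [out, rank]
    #   _s = torch.tensor(0.25)
    #   _scaled = (_up * _s.sqrt()) @ (_down * _s.sqrt())
    #   assert torch.allclose(_scaled, (_up @ _down) * _s, atol=1e-5)
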
    @torch.no_grad()
    def merge_out(self: Module, merge_out_weight=1.0):
        # make sure it is positive
@@ -251,6 +311,23 @@ class ToolkitModuleMixin:
            org_sd["weight"] = weight.to(orig_dtype)
        self.org_module[0].load_state_dict(org_sd)

    def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
        # LoRM (Low Rank Middle) is a method to reduce the number of parameters in a module while keeping the
        # inputs and outputs the same. It is basically a LoRA with the original module removed.

        # if a state dict is passed, use those weights instead of extracting
        # todo load from state dict
        network: Network = self.network_ref()
        lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name)

        extract_mode = lorm_config.extract_mode
        extract_mode_param = lorm_config.extract_mode_param
        parameter_threshold = lorm_config.parameter_threshold
        self.extract_weight(
            extract_mode=extract_mode,
            extract_mode_param=extract_mode_param
        )

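
# A minimal sketch of the LoRM idea described above: swap a large linear
# layer for a down/up pair with the same input/output shape and drop the
# original weights entirely. Purely illustrative; this is not the toolkit's
# API, and the SVD initialization mirrors the extraction sketch earlier.
def sketch_lormify(linear: nn.Linear, rank: int) -> nn.Sequential:
    down = nn.Linear(linear.in_features, rank, bias=False)
    up = nn.Linear(rank, linear.out_features, bias=linear.bias is not None)
    U, S, Vh = torch.linalg.svd(linear.weight.detach().float(), full_matrices=False)
    down.weight.data = Vh[:rank, :]
    up.weight.data = U[:, :rank] * S[:rank].unsqueeze(0)
    if linear.bias is not None:
        up.bias.data = linear.bias.data.clone()
    return nn.Sequential(down, up)  # same in/out shape, far fewer parameters
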
class ToolkitNetworkMixin:
    def __init__(
@@ -260,6 +337,8 @@ class ToolkitNetworkMixin:
            train_unet: Optional[bool] = True,
            is_sdxl=False,
            is_v2=False,
            network_config: Optional[NetworkConfig] = None,
            is_lorm=False,
            **kwargs
    ):
        self.train_text_encoder = train_text_encoder
@@ -267,11 +346,14 @@ class ToolkitNetworkMixin:
        self.is_checkpointing = False
        self._multiplier: float = 1.0
        self.is_active: bool = False
        self._is_normalizing: bool = False
        self.is_sdxl = is_sdxl
        self.is_v2 = is_v2
        self.is_merged_in = False
        # super().__init__(*args, **kwargs)
        self.is_lorm = is_lorm
        self.network_config: NetworkConfig = network_config
        self.module_losses: List[torch.Tensor] = []
        self.lorm_train_mode: Literal['local', None] = None
        self.can_merge_in = not is_lorm

    def get_keymap(self: Network):
        if self.is_sdxl:
@@ -443,28 +525,41 @@ class ToolkitNetworkMixin:
        self.is_checkpointing = False
        self._update_checkpointing()

    @property
    def is_normalizing(self: Network) -> bool:
        return self._is_normalizing

    @is_normalizing.setter
    def is_normalizing(self: Network, value: bool):
        self._is_normalizing = value
        # for module in self.get_all_modules():
        #     module.is_normalizing = self._is_normalizing

    def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
        for module in self.get_all_modules():
            module.apply_stored_normalizer(target_normalize_scaler)

    def merge_in(self, merge_weight=1.0):
        self.is_merged_in = True
        for module in self.get_all_modules():
            module.merge_in(merge_weight)

    def merge_out(self, merge_weight=1.0):
    def merge_out(self: Network, merge_weight=1.0):
        if not self.is_merged_in:
            return
        self.is_merged_in = False
        for module in self.get_all_modules():
            module.merge_out(merge_weight)

    def extract_weight(
            self: Network,
            extract_mode: ExtractMode = "existing",
            extract_mode_param: Union[int, float] = None,
    ):
        if extract_mode_param is None:
            raise ValueError("extract_mode_param must be set")
        for module in tqdm(self.get_all_modules(), desc="Extracting weights"):
            module.extract_weight(
                extract_mode=extract_mode,
                extract_mode_param=extract_mode_param
            )

    def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None):
        for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"):
            module.setup_lorm(state_dict=state_dict)

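    # Hypothetical driver code for the network-level helpers above (the
    # trainer that would call them is not part of this diff):
    #
    #   network.setup_lorm()  # extract every module per its lorm_config
    #   network.extract_weight(extract_mode='ratio', extract_mode_param=0.5)
    #   print(network.calculate_lorem_parameter_reduction(), "params removed")
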
    def calculate_lorem_parameter_reduction(self):
        params_reduced = 0
        for module in self.get_all_modules():
            num_orig_module_params = count_parameters(module.org_module[0])
            num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up)
            params_reduced += (num_orig_module_params - num_lorem_params)

        return params_reduced

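
# Back-of-envelope arithmetic for the reduction computed above, with made-up
# sizes: replacing a 1024x1024 linear module by a rank-32 down/up pair.
orig_params = 1024 * 1024              # 1,048,576
lorm_params = 1024 * 32 + 32 * 1024    # 65,536
print(orig_params - lorm_params)       # 983,040 parameters saved in this module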