Improved lorm extraction and training

This commit is contained in:
Jaret Burkett
2023-10-28 08:21:59 -06:00
parent 0a79ac9604
commit 6f3e0d5af2
10 changed files with 559 additions and 196 deletions


@@ -1,17 +1,23 @@
import json
import os
from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal
import torch
from torch import nn
import weakref
from tqdm import tqdm
from toolkit.config_modules import NetworkConfig
from toolkit.lorm import extract_conv, extract_linear, count_parameters
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
if TYPE_CHECKING:
from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
from toolkit.stable_diffusion_model import StableDiffusion
Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule']
@@ -26,6 +32,15 @@ CONV_MODULES = [
'LoRACompatibleConv'
]
ExtractMode = Literal[
'existing',
'fixed',
'threshold',
'ratio',
'quantile',
'percentage'
]
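Aside from 'existing', which reuses the module's current lora_dim as a fixed rank (see extract_weight below), each mode is a different rule for choosing the extraction rank. A minimal sketch of how such rank selection could work on the singular values of a 2D weight; this is an illustration only, since the real logic lives in toolkit.lorm's extract_linear/extract_conv, which this diff does not show:

def pick_rank(weight: torch.Tensor, mode: ExtractMode, param) -> int:
    # Hypothetical illustration of the mode semantics; toolkit.lorm may differ.
    s = torch.linalg.svdvals(weight.float())  # singular values, descending
    if mode == 'fixed':        # use the requested rank directly
        return int(param)
    if mode == 'threshold':    # keep singular values above an absolute cutoff
        return max(1, int((s > param).sum()))
    if mode == 'ratio':        # keep values above a fraction of the largest
        return max(1, int((s > s[0] * param).sum()))
    if mode == 'quantile':     # keep enough values to cover a fraction of the total mass
        cum = torch.cumsum(s, dim=0) / s.sum()
        return int((cum < param).sum()) + 1
    if mode == 'percentage':   # keep a percentage of the available ranks
        return max(1, int(s.numel() * param))
    raise ValueError(f"unknown extract mode: {mode}")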
def broadcast_and_multiply(tensor, multiplier):
# Determine the number of dimensions required
@@ -41,20 +56,101 @@ def broadcast_and_multiply(tensor, multiplier):
return result
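A usage sketch, assuming the elided body mirrors add_bias's dimension matching by unsqueezing the multiplier until it broadcasts (shapes here are illustrative):

x = torch.randn(2, 4, 8, 8)          # (B, C, H, W), illustrative
m = torch.tensor([0.5, 1.0])         # one multiplier per batch item
y = broadcast_and_multiply(x, m)     # m is expanded to (2, 1, 1, 1) first
assert torch.allclose(y[0], x[0] * 0.5)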
def add_bias(tensor, bias):
if bias is None:
return tensor
# add batch dim
bias = bias.unsqueeze(0)
bias = torch.cat([bias] * tensor.size(0), dim=0)
# Determine the number of dimensions required
num_extra_dims = tensor.dim() - bias.dim()
# Unsqueezing the tensor to match the dimensionality
for _ in range(num_extra_dims):
bias = bias.unsqueeze(-1)
# move the bias's feature axis onto the tensor's channel dim if needed
if bias.size(1) != tensor.size(1):
if len(bias.size()) == 3:
bias = bias.permute(0, 2, 1)
elif len(bias.size()) == 4:
bias = bias.permute(0, 3, 1, 2)
# Add the broadcasted bias to the output tensor
try:
result = tensor + bias
except RuntimeError as e:
print(e)
print(tensor.size())
print(bias.size())
raise e
return result
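For example, with a linear output of shape (B, T, C) and a per-feature bias of shape (C,), the bias first lands on dim 1 and is then permuted onto the feature axis before the broadcast add (shapes are illustrative):

x = torch.randn(2, 77, 768)    # (batch, tokens, features) from a linear layer
b = torch.randn(768)           # per-feature bias
out = add_bias(x, b)           # bias is permuted to (2, 1, 768), then added
assert out.shape == x.shape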
class ExtractableModuleMixin:
def extract_weight(
self: Module,
extract_mode: ExtractMode = "existing",
extract_mode_param: Optional[Union[int, float]] = None,
):
device = self.lora_down.weight.device
weight_to_extract = self.org_module[0].weight
if extract_mode == "existing":
extract_mode = 'fixed'
extract_mode_param = self.lora_dim
if self.org_module[0].__class__.__name__ in CONV_MODULES:
# do conv extraction
down_weight, up_weight, new_dim, diff = extract_conv(
weight=weight_to_extract.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=device
)
elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
# do linear extraction
down_weight, up_weight, new_dim, diff = extract_linear(
weight=weight_to_extract.clone().detach().float(),
mode=extract_mode,
mode_param=extract_mode_param,
device=device,
)
else:
raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}")
self.lora_dim = new_dim
# inject weights into the param
self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach()
self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach()
# copy the bias if the original module has one and ours does too
if self.org_module[0].bias is not None and self.lora_up.bias is not None:
self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach()
# set up alphas
self.alpha = (self.alpha * 0) + down_weight.shape[0]
self.scale = self.alpha / self.lora_dim
# handle the trainable scalar method LoCon modules use
if hasattr(self, 'scalar'):
# scalar is a parameter; reset its value to 1.0
self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype)
class ToolkitModuleMixin:
def __init__(
self: Module,
*args,
network: Network,
call_super_init: bool = True,
**kwargs
):
if call_super_init:
super().__init__(*args, **kwargs)
self.network_ref: weakref.ref = weakref.ref(network)
self.is_checkpointing = False
# self.is_normalizing = False
self.normalize_scaler = 1.0
self._multiplier: Optional[Union[float, list, torch.Tensor]] = None
def _call_forward(self: Module, x):
@@ -100,11 +196,40 @@ class ToolkitModuleMixin:
return lx * scale
# this may receive an additional positional arg, depending on the module
def lorm_forward(self: Module, x, *args, **kwargs):
network: Network = self.network_ref()
if not network.is_active:
return self.org_forward(x, *args, **kwargs)
if network.lorm_train_mode == 'local':
# run the input through both paths and compute a loss between them
inputs = x.detach()
with torch.no_grad():
# get the target prediction from the original module
target_pred = self.org_forward(inputs, *args, **kwargs).detach()
with torch.set_grad_enabled(True):
# make a prediction with the lorm
lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True)))
local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
# backprop
local_loss.backward()
network.module_losses.append(local_loss.detach())
# return the original output so this trainer doesn't affect modules downstream
return target_pred
else:
return self.lora_up(self.lora_down(x))
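Because lorm_forward backpropagates each module's local MSE inside the forward pass, an outer training loop only needs to feed data, step an optimizer over the LoRA parameters, and read the accumulated losses from network.module_losses. A hedged sketch of such a loop; model, optimizer, and batches are assumptions, not part of this diff:

network.is_active = True
network.lorm_train_mode = 'local'
optimizer = torch.optim.AdamW(network.parameters(), lr=1e-4)

for batch in batches:                    # `batches` is assumed here
    network.module_losses.clear()
    optimizer.zero_grad()
    _ = model(batch)                     # each module backprops its own local loss
    optimizer.step()
    mean_loss = torch.stack(network.module_losses).mean()
    print(f"mean local loss: {mean_loss.item():.6f}")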
def forward(self: Module, x, *args, **kwargs):
skip = False
network: Network = self.network_ref()
if network.is_lorm:
# we are doing lorm
return self.lorm_forward(x, *args, **kwargs)
# skip if not active
if not network.is_active:
skip = True
@@ -130,40 +255,9 @@ class ToolkitModuleMixin:
if lora_output_batch_size != multiplier_batch_size:
num_interleaves = lora_output_batch_size // multiplier_batch_size
multiplier = multiplier.repeat_interleave(num_interleaves)
# multiplier = 1.0
if self.network_ref().is_normalizing:
with torch.no_grad():
# do this calculation without the set multiplier; keep its polarity but use magnitude 1.0
if isinstance(multiplier, torch.Tensor):
norm_multiplier = multiplier.clone().detach() * 10
norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0)
else:
norm_multiplier = multiplier
# Calculate the target magnitude for the combined output
orig_max = torch.max(torch.abs(org_forwarded))
# Calculate the additional increase in magnitude that lora_output would introduce
potential_max_increase = torch.max(
torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded))
epsilon = 1e-6 # Small constant to avoid division by zero
# Calculate the scaling factor for the lora_output
# to ensure that the potential increase in magnitude doesn't change the original max
normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
normalize_scaler = normalize_scaler.detach()
# save the scaler so it can be applied later
self.normalize_scaler = normalize_scaler.clone().detach()
lora_output = lora_output * normalize_scaler
x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
return x
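A tiny numeric check of the normalization formula above, with made-up values:

orig_max = torch.tensor(4.0)                 # made-up values
potential_max_increase = torch.tensor(1.0)
epsilon = 1e-6
normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
# ≈ 0.8: the scaled LoRA output can raise the max by roughly 0.8 instead of 1.0,
# keeping the combined magnitude close to the original's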
def enable_gradient_checkpointing(self: Module):
self.is_checkpointing = True
@@ -171,40 +265,6 @@ class ToolkitModuleMixin:
def disable_gradient_checkpointing(self: Module):
self.is_checkpointing = False
@torch.no_grad()
def apply_stored_normalizer(self: Module, target_normalize_scaler: float = 1.0):
"""
Applies the previously stored normalization scaler to the module.
This must be called before saving or the normalization will be lost.
It is probably best to call it after each batch as well.
We just scale the up/down weights to match the target scaler.
:return:
"""
# get state dict
state_dict = self.state_dict()
dtype = state_dict['lora_up.weight'].dtype
device = state_dict['lora_up.weight'].device
# todo should we do this at fp32?
if isinstance(self.normalize_scaler, torch.Tensor):
scaler = self.normalize_scaler.clone().detach()
else:
scaler = torch.tensor(self.normalize_scaler).to(device, dtype=dtype)
total_module_scale = scaler / target_normalize_scaler
num_modules_layers = 2 # up and down
up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
.to(device, dtype=dtype)
# apply the scaler to the up and down weights
for key in state_dict.keys():
if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
# do it in place so the params are updated
state_dict[key] *= up_down_scale
# reset the normalization scaler
self.normalize_scaler = target_normalize_scaler
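Splitting total_module_scale across both matrices via the square root works because the module's effective weight is the product lora_up @ lora_down, so scaling each factor by sqrt(s) scales the product by exactly s. A quick check with illustrative shapes:

up, down = torch.randn(8, 4), torch.randn(4, 16)
total_scale = torch.tensor(0.25)
root = torch.pow(total_scale, 1.0 / 2)       # same split as above
assert torch.allclose((up * root) @ (down * root),
                      (up @ down) * total_scale, atol=1e-5)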
@torch.no_grad()
def merge_out(self: Module, merge_out_weight=1.0):
# make sure it is positive
@@ -251,6 +311,23 @@ class ToolkitModuleMixin:
org_sd["weight"] = weight.to(orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
# LoRM (Low Rank Middle) is a method to reduce the number of parameters in a module while keeping its
# inputs and outputs the same. It is basically a LoRA with the original module removed.
# if a state dict is passed, use those weights instead of extracting
# todo load from state dict
network: Network = self.network_ref()
lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name)
extract_mode = lorm_config.extract_mode
extract_mode_param = lorm_config.extract_mode_param
parameter_threshold = lorm_config.parameter_threshold
self.extract_weight(
extract_mode=extract_mode,
extract_mode_param=extract_mode_param
)
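For a linear layer this trades d_in * d_out parameters for r * (d_in + d_out), so the saving grows quickly as the rank drops below the smaller dimension. A worked example with hypothetical sizes, not taken from this diff:

d_in, d_out, r = 768, 768, 64                # hypothetical sizes
orig_params = d_in * d_out                   # 589,824
lorm_params = r * d_in + r * d_out           # 98,304
print(f"saved {orig_params - lorm_params:,} params "
      f"({100 * (1 - lorm_params / orig_params):.1f}%)")
# saved 491,520 params (83.3%)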
class ToolkitNetworkMixin:
def __init__(
@@ -260,6 +337,8 @@ class ToolkitNetworkMixin:
train_unet: Optional[bool] = True,
is_sdxl=False,
is_v2=False,
network_config: Optional[NetworkConfig] = None,
is_lorm=False,
**kwargs
):
self.train_text_encoder = train_text_encoder
@@ -267,11 +346,14 @@ class ToolkitNetworkMixin:
self.is_checkpointing = False
self._multiplier: float = 1.0
self.is_active: bool = False
self._is_normalizing: bool = False
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_merged_in = False
# super().__init__(*args, **kwargs)
self.is_lorm = is_lorm
self.network_config: NetworkConfig = network_config
self.module_losses: List[torch.Tensor] = []
self.lorm_train_mode: Literal['local', None] = None
self.can_merge_in = not is_lorm
def get_keymap(self: Network):
if self.is_sdxl:
@@ -443,28 +525,41 @@ class ToolkitNetworkMixin:
self.is_checkpointing = False
self._update_checkpointing()
@property
def is_normalizing(self: Network) -> bool:
return self._is_normalizing
@is_normalizing.setter
def is_normalizing(self: Network, value: bool):
self._is_normalizing = value
# for module in self.get_all_modules():
# module.is_normalizing = self._is_normalizing
def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
for module in self.get_all_modules():
module.apply_stored_normalizer(target_normalize_scaler)
def merge_in(self, merge_weight=1.0):
self.is_merged_in = True
for module in self.get_all_modules():
module.merge_in(merge_weight)
def merge_out(self: Network, merge_weight=1.0):
if not self.is_merged_in:
return
self.is_merged_in = False
for module in self.get_all_modules():
module.merge_out(merge_weight)
def extract_weight(
self: Network,
extract_mode: ExtractMode = "existing",
extract_mode_param: Optional[Union[int, float]] = None,
):
if extract_mode_param is None:
raise ValueError("extract_mode_param must be set")
for module in tqdm(self.get_all_modules(), desc="Extracting weights"):
module.extract_weight(
extract_mode=extract_mode,
extract_mode_param=extract_mode_param
)
def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None):
for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"):
module.setup_lorm(state_dict=state_dict)
def calculate_lorem_parameter_reduction(self):
params_reduced = 0
for module in self.get_all_modules():
num_orig_module_params = count_parameters(module.org_module[0])
num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up)
params_reduced += (num_orig_module_params - num_lorem_params)
return params_reduced
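Assumed usage after LoRM setup, summing the per-module savings computed above (the surrounding calls are hypothetical, but the method matches the definition in this diff):

network.setup_lorm()
saved = network.calculate_lorem_parameter_reduction()
print(f"LoRM removed {saved:,} parameters")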