mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Improved lorm extraction and training
102 extensions_built_in/advanced_generator/PureLoraGenerator.py Normal file
@@ -0,0 +1,102 @@
import os
from collections import OrderedDict

from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
import gc
import torch
from jobs.process import BaseExtensionProcess
from toolkit.train_tools import get_torch_dtype


def flush():
    torch.cuda.empty_cache()
    gc.collect()


class PureLoraGenerator(BaseExtensionProcess):

    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.output_folder = self.get_conf('output_folder', required=True)
        self.device = self.get_conf('device', 'cuda')
        self.device_torch = torch.device(self.device)
        self.model_config = ModelConfig(**self.get_conf('model', required=True))
        self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
        self.dtype = self.get_conf('dtype', 'float16')
        self.torch_dtype = get_torch_dtype(self.dtype)
        lorm_config = self.get_conf('lorm', None)
        self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None

        self.device_state_preset = get_train_sd_device_state_preset(
            device=torch.device(self.device),
        )

        self.progress_bar = None
        self.sd = StableDiffusion(
            device=self.device,
            model_config=self.model_config,
            dtype=self.dtype,
        )

    def run(self):
        super().run()
        print("Loading model...")
        with torch.no_grad():
            self.sd.load_model()
            self.sd.unet.eval()
            self.sd.unet.to(self.device_torch)
            if isinstance(self.sd.text_encoder, list):
                for te in self.sd.text_encoder:
                    te.eval()
                    te.to(self.device_torch)
            else:
                self.sd.text_encoder.eval()
                self.sd.to(self.device_torch)

            print(f"Converting to LoRM UNet")
            # replace the unet with LoRMUnet
            convert_diffusers_unet_to_lorm(
                self.sd.unet,
                config=self.lorm_config,
            )

            sample_folder = os.path.join(self.output_folder)
            gen_img_config_list = []

            sample_config = self.generate_config
            start_seed = sample_config.seed
            current_seed = start_seed
            for i in range(len(sample_config.prompts)):
                if sample_config.walk_seed:
                    current_seed = start_seed + i

                filename = f"[time]_[count].{self.generate_config.ext}"
                output_path = os.path.join(sample_folder, filename)
                prompt = sample_config.prompts[i]
                extra_args = {}
                gen_img_config_list.append(GenerateImageConfig(
                    prompt=prompt,  # it will autoparse the prompt
                    width=sample_config.width,
                    height=sample_config.height,
                    negative_prompt=sample_config.neg,
                    seed=current_seed,
                    guidance_scale=sample_config.guidance_scale,
                    guidance_rescale=sample_config.guidance_rescale,
                    num_inference_steps=sample_config.sample_steps,
                    network_multiplier=sample_config.network_multiplier,
                    output_path=output_path,
                    output_ext=sample_config.ext,
                    adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
                    **extra_args
                ))

            # send to be generated
            self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
            print("Done generating images")
            # cleanup
            del self.sd
            gc.collect()
            torch.cuda.empty_cache()
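A minimal sketch of a config that would drive this process. The keys (output_folder, device, dtype, model, sample, lorm) mirror the get_conf calls above, but the concrete values and the checkpoint name are hypothetical:

config = OrderedDict({
    'output_folder': 'output/lorm_samples',  # hypothetical path
    'device': 'cuda',
    'dtype': 'float16',
    'model': {'name_or_path': 'runwayml/stable-diffusion-v1-5'},  # hypothetical checkpoint
    'sample': {
        'prompts': ['a photo of a corgi'],
        'width': 512, 'height': 512,
        'seed': 42, 'walk_seed': True,  # walk_seed makes prompt i use seed 42 + i
        'sample_steps': 20, 'guidance_scale': 7.0, 'ext': 'png',
    },
    'lorm': {'extract_mode': 'ratio', 'extract_mode_param': 0.25},
})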
@@ -19,7 +19,24 @@ class AdvancedReferenceGeneratorExtension(Extension):
        return ReferenceGenerator


# This is for generic training (LoRA, Dreambooth, FineTuning)
class PureLoraGenerator(Extension):
    # uid must be unique, it is how the extension is identified
    uid = "pure_lora_generator"

    # name is the name of the extension for printing
    name = "Pure LoRA Generator"

    # This is where your process class is loaded
    # keep your imports in here so they don't slow down the rest of the program
    @classmethod
    def get_process(cls):
        # import your process class here so it is only loaded when needed and return it
        from .PureLoraGenerator import PureLoraGenerator
        return PureLoraGenerator


AI_TOOLKIT_EXTENSIONS = [
    # you can put a list of extensions here
    AdvancedReferenceGeneratorExtension,
    AdvancedReferenceGeneratorExtension, PureLoraGenerator
]
@@ -32,7 +32,6 @@ class SDTrainer(BaseSDTrainProcess):
        if self.train_config.inverted_mask_prior:
            self.do_prior_prediction = True

    def before_model_load(self):
        pass

@@ -193,6 +192,15 @@ class SDTrainer(BaseSDTrainProcess):
            self.network.is_active = was_network_active
        return prior_pred

    def before_unet_predict(self):
        pass

    def after_unet_predict(self):
        pass

    def end_of_training_loop(self):
        pass

    def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):

        self.timer.start('preprocess_batch')

@@ -331,7 +339,6 @@ class SDTrainer(BaseSDTrainProcess):
            adapter_images_list = [adapter_images]
            mask_multiplier_list = [mask_multiplier]

        for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip(
                noisy_latents_list,
                noise_list,

@@ -366,7 +373,8 @@ class SDTrainer(BaseSDTrainProcess):

            # flush()
            pred_kwargs = {}
            if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
            if has_adapter_img and (
                    (self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
                with torch.set_grad_enabled(self.adapter is not None):
                    adapter = self.adapter if self.adapter else self.assistant_adapter
                    adapter_multiplier = get_adapter_multiplier()

@@ -406,8 +414,7 @@ class SDTrainer(BaseSDTrainProcess):
                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
                        conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)

            self.before_unet_predict()
            with self.timer('predict_unet'):
                noise_pred = self.sd.predict_noise(
                    latents=noisy_latents.to(self.device_torch, dtype=dtype),

@@ -416,6 +423,7 @@ class SDTrainer(BaseSDTrainProcess):
                    guidance_scale=1.0,
                    **pred_kwargs
                )
            self.after_unet_predict()

            with self.timer('calculate_loss'):
                noise = noise.to(self.device_torch, dtype=dtype).detach()

@@ -442,7 +450,7 @@ class SDTrainer(BaseSDTrainProcess):
                loss.backward()

            torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm)
            # flush()
            # flush()

            with self.timer('optimizer_step'):
                # apply gradients

@@ -460,4 +468,6 @@ class SDTrainer(BaseSDTrainProcess):
                {'loss': loss.item()}
            )

        self.end_of_training_loop()

        return loss_dict
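The three new hooks are deliberate no-ops so subclasses can attach behavior to the training loop without copying it; a minimal sketch of a hypothetical subclass:

class LoggingTrainer(SDTrainer):
    def before_unet_predict(self):
        # runs right before self.sd.predict_noise(...) each step
        pass

    def end_of_training_loop(self):
        # runs after the optimizer step, just before loss_dict is returned
        pass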
@@ -7,6 +7,7 @@ from typing import Union, List

import numpy as np
from diffusers import T2IAdapter
from safetensors.torch import save_file, load_file
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
import torch

@@ -18,7 +19,8 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatc
from toolkit.embedding import Embedding
from toolkit.ip_adapter import IPAdapter
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lorm import convert_diffusers_unet_to_lorm
from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
    lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.network_mixins import Network
from toolkit.optimizer import get_optimizer

@@ -128,6 +130,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
        is_training_adapter = self.adapter_config is not None and self.adapter_config.train

        self.do_lorm = self.get_conf('do_lorm', False)
        self.lorm_extract_mode = self.get_conf('lorm_extract_mode', 'ratio')
        self.lorm_extract_mode_param = self.get_conf('lorm_extract_mode_param', 0.25)
        # 'ratio', 0.25)

        # get the device state preset based on what we are training
        self.train_device_state_preset = get_train_sd_device_state_preset(

@@ -300,9 +305,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
        file_path = os.path.join(self.save_root, filename)
        prev_multiplier = self.network.multiplier
        self.network.multiplier = 1.0
        if self.network_config.normalize:
            # apply the normalization
            self.network.apply_stored_normalizer()

        # if we are doing embedding training as well, add that
        embedding_dict = self.embedding.state_dict() if self.embedding else None

@@ -427,6 +429,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
            print("load_weights not implemented for non-network models")
            return None

    def load_lorm(self):
        latest_save_path = self.get_latest_save_path()
        if latest_save_path is not None:
            # hacky way to reload weights for now
            # todo, do this
            state_dict = load_file(latest_save_path, device=self.device)
            self.sd.unet.load_state_dict(state_dict)

            meta = load_metadata_from_safetensors(latest_save_path)
            # if 'training_info' in OrderedDict keys
            if 'training_info' in meta and 'step' in meta['training_info']:
                self.step_num = meta['training_info']['step']
                self.start_step = self.step_num
                print(f"Found step {self.step_num} in metadata, starting from there")

    # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
    #     self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
    #     sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)

@@ -610,7 +627,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
            batch.mask_tensor = double_up_tensor(batch.mask_tensor)
            batch.control_tensor = double_up_tensor(batch.control_tensor)

        # remove grads for these
        noisy_latents.requires_grad = False
        noisy_latents = noisy_latents.detach()

@@ -712,15 +728,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # run base sd process run
        self.sd.load_model()

        if self.do_lorm:
            train_modules = convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27)
            for module in train_modules:
                p = module.parameters()
                for param in p:
                    param.requires_grad_(True)
                    params.append(param)

        dtype = get_torch_dtype(self.train_config.dtype)

        # model is loaded from BaseSDProcess

@@ -783,14 +790,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if not self.is_fine_tuning:
            if self.network_config is not None:
                # TODO should we completely switch to LycorisSpecialNetwork?

                network_kwargs = {}
                is_lycoris = False
                is_lorm = self.network_config.type.lower() == 'lorm'
                # default to LoCON if there are any conv layers or if it is named
                NetworkClass = LoRASpecialNetwork
                if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
                    NetworkClass = LycorisSpecialNetwork
                    is_lycoris = True

                if is_lorm:
                    network_kwargs['ignore_if_contains'] = lorm_ignore_if_contains
                    network_kwargs['parameter_threshold'] = lorm_parameter_threshold
                    network_kwargs['target_lin_modules'] = LORM_TARGET_REPLACE_MODULE

                # if is_lycoris:
                #     preset = PRESET['full']
                #     NetworkClass.apply_preset(preset)

@@ -810,6 +823,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    dropout=self.network_config.dropout,
                    use_text_encoder_1=self.model_config.use_text_encoder_1,
                    use_text_encoder_2=self.model_config.use_text_encoder_2,
                    use_bias=is_lorm,
                    is_lorm=is_lorm,
                    network_config=self.network_config,
                    **network_kwargs
                )

                self.network.force_to(self.device_torch, dtype=dtype)

@@ -824,6 +841,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    self.train_config.train_unet
                )

                if is_lorm:
                    self.network.is_lorm = True
                    # make sure it is on the right device
                    self.sd.unet.to(self.sd.device, dtype=dtype)
                    original_unet_param_count = count_parameters(self.sd.unet)
                    self.network.setup_lorm()
                    new_unet_param_count = original_unet_param_count - self.network.calculate_lorem_parameter_reduction()

                    print_lorm_extract_details(
                        start_num_params=original_unet_param_count,
                        end_num_params=new_unet_param_count,
                        num_replaced=len(self.network.get_all_modules()),
                    )

                self.network.prepare_grad_etc(text_encoder, unet)
                flush()

@@ -846,9 +877,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
                if self.train_config.gradient_checkpointing:
                    self.network.enable_gradient_checkpointing()

                # set the network to normalize if we are
                self.network.is_normalizing = self.network_config.normalize

                lora_name = self.name
                # need to adapt name so they are not mixed up
                if self.named_lora:

@@ -915,7 +943,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # set the device state preset before getting params
        self.sd.set_device_state(self.train_device_state_preset)

        # params = self.get_params()
        if len(params) == 0:
            # will only return savable weights and ones with grad

@@ -1050,9 +1077,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
            else:
                batch = None

            # turn on normalization if we are using it and it is not on
            if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
                self.network.is_normalizing = True
            # flush()
            ### HOOK ###
            self.timer.start('train_loop')

@@ -1078,11 +1102,6 @@ class BaseSDTrainProcess(BaseTrainProcess):

            self.progress_bar.set_postfix_str(prog_bar_string)

            # apply network normalizer if we are using it, not on regularization steps
            if self.network is not None and self.network.is_normalizing and not is_reg_step:
                with self.timer('apply_normalizer'):
                    self.network.apply_stored_normalizer()

            # if the batch is a DataLoaderBatchDTO, then we need to clean it up
            if isinstance(batch, DataLoaderBatchDTO):
                with self.timer('batch_cleanup'):

@@ -40,7 +40,48 @@ class SampleConfig:
        self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)


NetworkType = Literal['lora', 'locon']
class LormModuleSettingsConfig:
    def __init__(self, **kwargs):
        self.contains: str = kwargs.get('contains', '4nt$3')
        self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
        # min num parameters to attach to
        self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
        self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)


class LoRMConfig:
    def __init__(self, **kwargs):
        self.extract_mode: str = kwargs.get('extract_mode', 'ratio')
        self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25)
        self.parameter_threshold: int = kwargs.get('parameter_threshold', 0)
        module_settings = kwargs.get('module_settings', [])
        default_module_settings = {
            'extract_mode': self.extract_mode,
            'extract_mode_param': self.extract_mode_param,
            'parameter_threshold': self.parameter_threshold,
        }
        module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings]
        self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for
                                                                module_setting in module_settings]

    def get_config_for_module(self, block_name):
        for setting in self.module_settings:
            contain_pieces = setting.contains.split('|')
            if all(contain_piece in block_name for contain_piece in contain_pieces):
                return setting
            # try replacing the . with _
            contain_pieces = setting.contains.replace('.', '_').split('|')
            if all(contain_piece in block_name for contain_piece in contain_pieces):
                return setting
        # do default
        return LormModuleSettingsConfig(**{
            'extract_mode': self.extract_mode,
            'extract_mode_param': self.extract_mode_param,
            'parameter_threshold': self.parameter_threshold,
        })
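A short sketch of how the per-module settings resolve; the module names below are illustrative of diffusers-style UNet naming, not taken from the commit:

lorm_config = LoRMConfig(
    extract_mode='ratio',
    extract_mode_param=0.25,
    module_settings=[
        # unset keys inherit the top-level defaults above
        {'contains': 'attn1|to_q', 'extract_mode_param': 0.1},
    ],
)

cfg = lorm_config.get_config_for_module('down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q')
# both 'attn1' and 'to_q' appear in the name, so cfg.extract_mode_param == 0.1

cfg = lorm_config.get_config_for_module('mid_block.resnets.0.conv1')
# no setting matches, so the top-level defaults come back (ratio, 0.25, threshold 0)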

NetworkType = Literal['lora', 'locon', 'lorm']


class NetworkConfig:

@@ -58,12 +99,22 @@ class NetworkConfig:
        self.alpha: float = kwargs.get('alpha', 1.0)
        self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
        self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
        self.normalize = kwargs.get('normalize', False)
        self.dropout: Union[float, None] = kwargs.get('dropout', None)

        self.lorm_config: Union[LoRMConfig, None] = None
        lorm = kwargs.get('lorm', None)
        if lorm is not None:
            self.lorm_config: LoRMConfig = LoRMConfig(**lorm)

        if self.type == 'lorm':
            # set linear to arbitrary values so it makes them
            self.linear = 4
            self.rank = 4


AdapterTypes = Literal['t2i', 'ip', 'ip+']


class AdapterConfig:
    def __init__(self, **kwargs):
        self.type: AdapterTypes = kwargs.get('type', 't2i')  # t2i, ip

@@ -90,6 +141,7 @@ class EmbeddingConfig:

ContentOrStyleType = Literal['balanced', 'style', 'content']
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']


class TrainConfig:
    def __init__(self, **kwargs):
        self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')

@@ -138,7 +190,8 @@ class TrainConfig:

        match_adapter_assist = kwargs.get('match_adapter_assist', False)
        self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
        self.loss_target: LossTarget = kwargs.get('loss_target', 'noise')  # noise, source, unaugmented, differential_noise
        self.loss_target: LossTarget = kwargs.get('loss_target',
                                                  'noise')  # noise, source, unaugmented, differential_noise

        # When a mask is passed in a dataset, and this is true,
        # we will predict noise without the LoRA network and use the prediction as a target for

@@ -151,7 +204,6 @@ class TrainConfig:
            self.match_adapter_chance = 1.0



class ModelConfig:
    def __init__(self, **kwargs):
        self.name_or_path: str = kwargs.get('name_or_path', None)

@@ -216,7 +268,7 @@ class SliderConfig:
        self.prompt_file: str = kwargs.get('prompt_file', None)
        self.prompt_tensors: str = kwargs.get('prompt_tensors', None)
        self.batch_full_slide: bool = kwargs.get('batch_full_slide', True)
        self.use_adapter: bool = kwargs.get('use_adapter', None)  # depth
        self.use_adapter: bool = kwargs.get('use_adapter', None)  # depth
        self.adapter_img_dir = kwargs.get('adapter_img_dir', None)
        self.high_ram = kwargs.get('high_ram', False)

@@ -267,9 +319,11 @@ class DatasetConfig:
        self.augments: List[str] = kwargs.get('augments', [])
        self.control_path: str = kwargs.get('control_path', None)  # depth maps, etc
        self.alpha_mask: bool = kwargs.get('alpha_mask', False)  # if true, will use alpha channel as mask
        self.mask_path: str = kwargs.get('mask_path', None)  # focus mask (black and white. White has higher loss than black)
        self.mask_path: str = kwargs.get('mask_path',
                                         None)  # focus mask (black and white. White has higher loss than black)
        self.mask_min_value: float = kwargs.get('mask_min_value', 0.01)  # min value for the mask, 0 - 1
        self.poi: Union[str, None] = kwargs.get('poi', None)  # if one is set and in json data, will be used as auto crop scale point of interest
        self.poi: Union[str, None] = kwargs.get('poi',
                                                None)  # if one is set and in json data, will be used as auto crop scale point of interest
        self.num_repeats: int = kwargs.get('num_repeats', 1)  # number of times to repeat dataset
        # cache latents will store them in memory
        self.cache_latents: bool = kwargs.get('cache_latents', False)

@@ -525,4 +579,4 @@ class GenerateImageConfig:
        unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
    ):
        # this is called after prompt embeds are encoded. We can override them in the future here
        pass
        pass

@@ -7,7 +7,9 @@ from typing import List, Optional, Dict, Type, Union

import torch
from transformers import CLIPTextModel

from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
from .config_modules import NetworkConfig
from .lorm import count_parameters
from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin
from .paths import SD_SCRIPTS_ROOT

sys.path.append(SD_SCRIPTS_ROOT)

@@ -30,7 +32,7 @@ CONV_MODULES = [
    'LoRACompatibleConv'
]

class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

@@ -46,13 +48,17 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
            rank_dropout=None,
            module_dropout=None,
            network: 'LoRASpecialNetwork' = None,
            parent=None,
            use_bias: bool = False,
            **kwargs
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__(network=network)
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        self.lora_name = lora_name
        self.scalar = torch.tensor(1.0)
        # check if parent has bias. if not force use_bias to False
        if org_module.bias is None:
            use_bias = False

        if org_module.__class__.__name__ in CONV_MODULES:
            in_dim = org_module.in_channels

@@ -73,10 +79,10 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error

@@ -95,8 +101,6 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.is_checkpointing = False
        self.is_normalizing = False
        self.normalize_scaler = 1.0

    def apply_to(self):
        self.org_forward = self.org_module[0].forward

@@ -143,6 +147,13 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            train_unet: Optional[bool] = True,
            is_sdxl=False,
            is_v2=False,
            use_bias: bool = False,
            is_lorm: bool = False,
            ignore_if_contains=None,
            parameter_threshold: float = 0.0,
            target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
            target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
            **kwargs
    ) -> None:
        """
        LoRA network: there are a lot of arguments, but the patterns are as follows

@@ -154,7 +165,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        """
        # call the parent of the parent we are replacing (LoRANetwork) init
        torch.nn.Module.__init__(self)

        ToolkitNetworkMixin.__init__(
            self,
            train_text_encoder=train_text_encoder,
            train_unet=train_unet,
            is_sdxl=is_sdxl,
            is_v2=is_v2,
            is_lorm=is_lorm,
            **kwargs
        )
        if ignore_if_contains is None:
            ignore_if_contains = []
        self.ignore_if_contains = ignore_if_contains
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim

@@ -165,13 +187,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        self.is_checkpointing = False
        self._multiplier: float = 1.0
        self.is_active: bool = False
        self._is_normalizing: bool = False
        self.torch_multiplier = None
        # triggers the state updates
        self.multiplier = multiplier
        self.is_sdxl = is_sdxl
        self.is_v2 = is_v2
        self.is_merged_in = False

        if modules_dim is not None:
            print(f"create LoRA network from weights")

@@ -217,7 +237,15 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                is_conv2d = child_module.__class__.__name__ in CONV_MODULES
                is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                if is_linear or is_conv2d:
                skip = False
                if any([word in child_name for word in self.ignore_if_contains]):
                    skip = True

                # see if it is over threshold
                if count_parameters(child_module) < parameter_threshold:
                    skip = True

                if (is_linear or is_conv2d) and not skip:
                    lora_name = prefix + "." + name + "." + child_name
                    lora_name = lora_name.replace(".", "_")

@@ -265,6 +293,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                        module_dropout=module_dropout,
                        network=self,
                        parent=module,
                        use_bias=use_bias,
                    )
                    loras.append(lora)
        return loras, skipped

@@ -295,9 +324,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
        target_modules = target_lin_modules
        if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
            target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
            target_modules += target_conv_modules

        if train_unet:
            self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)

@@ -6,6 +6,8 @@ from diffusers import UNet2DConditionModel
from torch import Tensor
from tqdm import tqdm

from toolkit.config_modules import LoRMConfig

conv = nn.Conv2d
lin = nn.Linear
_size_2_t = Union[int, Tuple[int, int]]

@@ -29,12 +31,13 @@ CONV_MODULES = [

UNET_TARGET_REPLACE_MODULE = [
    "Transformer2DModel",
    # "BasicTransformerBlock",
    # "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
]

LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE

UNET_TARGET_REPLACE_NAME = [
    "conv_in",
    "conv_out",

@@ -279,13 +282,38 @@ def compute_optimal_bias(original_module, linear_down, linear_up, X):
    return optimal_bias


def format_with_commas(n):
    return f"{n:,}"


def print_lorm_extract_details(
        start_num_params: int,
        end_num_params: int,
        num_replaced: int,
):
    start_formatted = format_with_commas(start_num_params)
    end_formatted = format_with_commas(end_num_params)
    num_replaced_formatted = format_with_commas(num_replaced)

    width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))

    print(f"Convert UNet result:")
    print(f" - converted: {num_replaced:>{width},} modules")
    print(f" - start: {start_num_params:>{width},} params")
    print(f" - end: {end_num_params:>{width},} params")


lorm_ignore_if_contains = [
    'proj_out', 'proj_in',
]

lorm_parameter_threshold = 1000000


@torch.no_grad()
def convert_diffusers_unet_to_lorm(
        unet: UNet2DConditionModel,
        extract_mode: ExtractMode = "percentile",
        mode_param: Union[int, float] = 0.5,
        parameter_threshold: int = 500000,
        # parameter_threshold: int = 1500000
        config: LoRMConfig,
):
    print('Converting UNet to LoRM UNet')
    start_num_params = count_parameters(unet)

@@ -299,8 +327,6 @@ def convert_diffusers_unet_to_lorm(
    ignore_if_contains = [
        'proj_out', 'proj_in',
    ]
    def format_with_commas(n):
        return f"{n:,}"

    for name, module in named_modules:
        module_name = module.__class__.__name__

@@ -311,6 +337,13 @@ def convert_diffusers_unet_to_lorm(
            combined_name = f"{name}.{child_name}"
            # if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None:
            #     pass

            lorm_config = config.get_config_for_module(combined_name)

            extract_mode = lorm_config.extract_mode
            extract_mode_param = lorm_config.extract_mode_param
            parameter_threshold = lorm_config.parameter_threshold

            if any([word in child_name for word in ignore_if_contains]):
                pass

@@ -322,7 +355,7 @@ def convert_diffusers_unet_to_lorm(
                down_weight, up_weight, lora_dim, diff = extract_linear(
                    weight=child_module.weight.clone().detach().float(),
                    mode=extract_mode,
                    mode_param=mode_param,
                    mode_param=extract_mode_param,
                    device=child_module.weight.device,
                )
                down_weight = down_weight.to(dtype=dtype)

@@ -362,7 +395,7 @@ def convert_diffusers_unet_to_lorm(
                down_weight, up_weight, lora_dim, diff = extract_conv(
                    weight=child_module.weight.clone().detach().float(),
                    mode=extract_mode,
                    mode_param=mode_param,
                    mode_param=extract_mode_param,
                    device=child_module.weight.device,
                )
                down_weight = down_weight.to(dtype=dtype)

@@ -395,30 +428,25 @@ def convert_diffusers_unet_to_lorm(
                replace_module_by_path(unet, combined_name, new_module)
                converted_modules.append(new_module)
                num_replaced += 1
                layer_names_replaced.append(f"{combined_name} - {format_with_commas(count_parameters(child_module))}")
                layer_names_replaced.append(
                    f"{combined_name} - {format_with_commas(count_parameters(child_module))}")

            pbar.update(1)
    pbar.close()
    end_num_params = count_parameters(unet)

    start_formatted = format_with_commas(start_num_params)
    end_formatted = format_with_commas(end_num_params)
    num_replaced_formatted = format_with_commas(num_replaced)

    width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted))

    def sorting_key(s):
        # Extract the number part, remove commas, and convert to integer
        return int(s.split("-")[1].strip().replace(",", ""))

    sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True)

    for layer_name in sorted_layer_names_replaced:
        print(layer_name)

    print(f"Convert UNet result:")
    print(f" - converted: {num_replaced:>{width},} modules")
    print(f" - start: {start_num_params:>{width},} params")
    print(f" - end: {end_num_params:>{width},} params")
    print_lorm_extract_details(
        start_num_params=start_num_params,
        end_num_params=end_num_params,
        num_replaced=num_replaced,
    )

    return converted_modules
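For context, extract_linear/extract_conv style low-rank extraction is conventionally a truncated SVD of the original weight. A minimal sketch of the linear case; this illustrates the technique, not the toolkit's exact implementation, and the 'ratio' rank rule shown is one plausible reading:

def extract_linear_sketch(weight: torch.Tensor, ratio: float = 0.25):
    # weight: (out_features, in_features)
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    out_dim, in_dim = weight.shape
    # pick a rank so the two factors hold roughly `ratio` of the original params:
    # rank * (in + out) ~= ratio * in * out
    rank = max(1, int(ratio * in_dim * out_dim / (in_dim + out_dim)))
    sqrt_s = torch.diag(S[:rank].sqrt())
    down_weight = sqrt_s @ Vh[:rank]            # (rank, in_features)
    up_weight = U[:, :rank] @ sqrt_s            # (out_features, rank)
    diff = weight - up_weight @ down_weight     # residual the low-rank pair cannot express
    return down_weight, up_weight, rank, diff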

@@ -8,20 +8,19 @@ from lycoris.modules.glora import GLoRAModule
from torch import nn
from transformers import CLIPTextModel
from torch.nn import functional as F
from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin
from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin

# diffusers specific stuff
LINEAR_MODULES = [
    'Linear',
    'LoRACompatibleLinear'
    # 'GroupNorm',
]
CONV_MODULES = [
    'Conv2d',
    'LoRACompatibleConv'
]

class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
class LoConSpecialModule(ToolkitModuleMixin, LoConModule, ExtractableModuleMixin):
    def __init__(
            self,
            lora_name, org_module: nn.Module,

@@ -30,18 +29,20 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
            dropout=0., rank_dropout=0., module_dropout=0.,
            use_cp=False,
            network: 'LycorisSpecialNetwork' = None,
            parent=None,
            use_bias=False,
            **kwargs,
    ):
        """ if alpha == 0 or None, alpha is rank (no scaling). """
        # call super of super
        ToolkitModuleMixin.__init__(self, network=network)
        torch.nn.Module.__init__(self)
        # call super of
        super().__init__(call_super_init=False, network=network)
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.cp = False

        # check if parent has bias. if not force use_bias to False
        if org_module.bias is None:
            use_bias = False

        self.scalar = nn.Parameter(torch.tensor(0.0))
        orig_module_name = org_module.__class__.__name__

@@ -61,7 +62,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
                self.cp = True
            else:
                self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False)
            self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
            self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=use_bias)
        elif orig_module_name in LINEAR_MODULES:
            self.isconv = False
            self.down_op = F.linear

@@ -74,7 +75,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=use_bias)
        else:
            raise NotImplementedError
        self.shape = org_module.weight.shape

@@ -159,10 +160,16 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
            train_text_encoder: bool = True,
            use_text_encoder_1: bool = True,
            use_text_encoder_2: bool = True,
            use_bias: bool = False,
            is_lorm: bool = False,
            **kwargs,
    ) -> None:
        # call ToolkitNetworkMixin super
        super().__init__(
        ToolkitNetworkMixin.__init__(
            self,
            train_text_encoder=train_text_encoder,
            train_unet=train_unet,
            is_lorm=is_lorm,
            **kwargs
        )
        # call the parent of the parent LycorisNetwork

@@ -217,7 +224,6 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
        loras = []
        # remove this
        named_modules = root_module.named_modules()
        modules = root_module.modules()
        # add a few to the generator

        for name, module in named_modules:

@@ -241,6 +247,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                        use_cp,
                        network=self,
                        parent=module,
                        use_bias=use_bias,
                        **kwargs
                    )
                elif child_module.__class__.__name__ in CONV_MODULES:

@@ -253,6 +260,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                        use_cp,
                        network=self,
                        parent=module,
                        use_bias=use_bias,
                        **kwargs
                    )
                elif conv_lora_dim > 0:

@@ -263,6 +271,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                        use_cp,
                        network=self,
                        parent=module,
                        use_bias=use_bias,
                        **kwargs
                    )
                else:

@@ -285,6 +294,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                        use_cp,
                        parent=module,
                        network=self,
                        use_bias=use_bias,
                        **kwargs
                    )
                elif module.__class__.__name__ == 'Conv2d':

@@ -297,6 +307,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                        use_cp,
                        network=self,
                        parent=module,
                        use_bias=use_bias,
                        **kwargs
                    )
                elif conv_lora_dim > 0:

@@ -307,6 +318,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
                        use_cp,
                        network=self,
                        parent=module,
                        use_bias=use_bias,
                        **kwargs
                    )
                else:

@@ -1,17 +1,23 @@
import json
import os
from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal

import torch
from torch import nn
import weakref

from tqdm import tqdm

from toolkit.config_modules import NetworkConfig
from toolkit.lorm import extract_conv, extract_linear, count_parameters
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT

if TYPE_CHECKING:
    from toolkit.lycoris_special import LycorisSpecialNetwork, LoConSpecialModule
    from toolkit.lora_special import LoRASpecialNetwork, LoRAModule
    from toolkit.stable_diffusion_model import StableDiffusion

Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork']
Module = Union['LoConSpecialModule', 'LoRAModule']

@@ -26,6 +32,15 @@ CONV_MODULES = [
    'LoRACompatibleConv'
]

ExtractMode = Union[
    'existing',
    'fixed',
    'threshold',
    'ratio',
    'quantile',
    'percentage'
]


def broadcast_and_multiply(tensor, multiplier):
    # Determine the number of dimensions required

@@ -41,20 +56,101 @@ def broadcast_and_multiply(tensor, multiplier):
    return result


def add_bias(tensor, bias):
    if bias is None:
        return tensor
    # add batch dim
    bias = bias.unsqueeze(0)
    bias = torch.cat([bias] * tensor.size(0), dim=0)
    # Determine the number of dimensions required
    num_extra_dims = tensor.dim() - bias.dim()

    # Unsqueezing the tensor to match the dimensionality
    for _ in range(num_extra_dims):
        bias = bias.unsqueeze(-1)

    # we may need to swap -1 for -2
    if bias.size(1) != tensor.size(1):
        if len(bias.size()) == 3:
            bias = bias.permute(0, 2, 1)
        elif len(bias.size()) == 4:
            bias = bias.permute(0, 3, 1, 2)

    # Adding the broadcasted bias to the output tensor
    try:
        result = tensor + bias
    except RuntimeError as e:
        print(e)
        print(tensor.size())
        print(bias.size())
        raise e

    return result


class ExtractableModuleMixin:
    def extract_weight(
            self: Module,
            extract_mode: ExtractMode = "existing",
            extract_mode_param: Union[int, float] = None,
    ):
        device = self.lora_down.weight.device
        weight_to_extract = self.org_module[0].weight
        if extract_mode == "existing":
            extract_mode = 'fixed'
            extract_mode_param = self.lora_dim

        if self.org_module[0].__class__.__name__ in CONV_MODULES:
            # do conv extraction
            down_weight, up_weight, new_dim, diff = extract_conv(
                weight=weight_to_extract.clone().detach().float(),
                mode=extract_mode,
                mode_param=extract_mode_param,
                device=device
            )

        elif self.org_module[0].__class__.__name__ in LINEAR_MODULES:
            # do linear extraction
            down_weight, up_weight, new_dim, diff = extract_linear(
                weight=weight_to_extract.clone().detach().float(),
                mode=extract_mode,
                mode_param=extract_mode_param,
                device=device,
            )
        else:
            raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}")

        self.lora_dim = new_dim

        # inject weights into the param
        self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach()
        self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach()

        # copy bias if we have one and are using them
        if self.org_module[0].bias is not None and self.lora_up.bias is not None:
            self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach()

        # set up alphas
        self.alpha = (self.alpha * 0) + down_weight.shape[0]
        self.scale = self.alpha / self.lora_dim

        # assign them

        # handle trainable scaler method locon does
        if hasattr(self, 'scalar'):
            # scalar is a parameter, update the value with 1.0
            self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype)


class ToolkitModuleMixin:
    def __init__(
            self: Module,
            *args,
            network: Network,
            call_super_init: bool = True,
            **kwargs
    ):
        if call_super_init:
            super().__init__(*args, **kwargs)
        self.network_ref: weakref.ref = weakref.ref(network)
        self.is_checkpointing = False
        # self.is_normalizing = False
        self.normalize_scaler = 1.0
        self._multiplier: Union[float, list, torch.Tensor] = None

    def _call_forward(self: Module, x):
@@ -100,11 +196,40 @@ class ToolkitModuleMixin:

        return lx * scale

    # this may get an additional positional arg or not
    def lorm_forward(self: Network, x, *args, **kwargs):
        network: Network = self.network_ref()
        if not network.is_active:
            return self.org_forward(x, *args, **kwargs)

        if network.lorm_train_mode == 'local':
            # we are going to predict input with both and do a loss on them
            inputs = x.detach()
            with torch.no_grad():
                # get the local prediction
                target_pred = self.org_forward(inputs, *args, **kwargs).detach()
            with torch.set_grad_enabled(True):
                # make a prediction with the lorm
                lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True)))

                local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float())
                # backprop
                local_loss.backward()

            network.module_losses.append(local_loss.detach())
            # return the original as we don't want our trainer to affect ones down the line
            return target_pred

        else:
            return self.lora_up(self.lora_down(x))

    def forward(self: Module, x, *args, **kwargs):
        skip = False
        network = self.network_ref()
        network: Network = self.network_ref()
        if network.is_lorm:
            # we are doing lorm
            return self.lorm_forward(x, *args, **kwargs)

        # skip if not active
        if not network.is_active:
            skip = True

@@ -130,40 +255,9 @@ class ToolkitModuleMixin:
        if lora_output_batch_size != multiplier_batch_size:
            num_interleaves = lora_output_batch_size // multiplier_batch_size
            multiplier = multiplier.repeat_interleave(num_interleaves)
        # multiplier = 1.0

        if self.network_ref().is_normalizing:
            with torch.no_grad():

                # do this calculation without set multiplier and instead use same polarity, but with 1.0 multiplier
                if isinstance(multiplier, torch.Tensor):
                    norm_multiplier = multiplier.clone().detach() * 10
                    norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0)
                else:
                    norm_multiplier = multiplier

                # get a dim array from orig forward that had index of all dimensions except the batch and channel

                # Calculate the target magnitude for the combined output
                orig_max = torch.max(torch.abs(org_forwarded))

                # Calculate the additional increase in magnitude that lora_output would introduce
                potential_max_increase = torch.max(
                    torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded))

                epsilon = 1e-6  # Small constant to avoid division by zero

                # Calculate the scaling factor for the lora_output
                # to ensure that the potential increase in magnitude doesn't change the original max
                normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon)
                normalize_scaler = normalize_scaler.detach()

                # save the scaler so it can be applied later
                self.normalize_scaler = normalize_scaler.clone().detach()

                lora_output = lora_output * normalize_scaler

        return org_forwarded + broadcast_and_multiply(lora_output, multiplier)
        x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
        return x

    def enable_gradient_checkpointing(self: Module):
        self.is_checkpointing = True

@@ -171,40 +265,6 @@ class ToolkitModuleMixin:
    def disable_gradient_checkpointing(self: Module):
        self.is_checkpointing = False

    @torch.no_grad()
    def apply_stored_normalizer(self: Module, target_normalize_scaler: float = 1.0):
        """
        Applies the previously calculated normalization to the module.
        This must be called before saving or normalization will be lost.
        It is probably best to call after each batch as well.
        We just scale the up down weights to match this vector
        :return:
        """
        # get state dict
        state_dict = self.state_dict()
        dtype = state_dict['lora_up.weight'].dtype
        device = state_dict['lora_up.weight'].device

        # todo should we do this at fp32?
        if isinstance(self.normalize_scaler, torch.Tensor):
            scaler = self.normalize_scaler.clone().detach()
        else:
            scaler = torch.tensor(self.normalize_scaler).to(device, dtype=dtype)

        total_module_scale = scaler / target_normalize_scaler
        num_modules_layers = 2  # up and down
        up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \
            .to(device, dtype=dtype)

        # apply the scaler to the up and down weights
        for key in state_dict.keys():
            if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'):
                # do it in place so params are updated
                state_dict[key] *= up_down_scale

        # reset the normalization scaler
        self.normalize_scaler = target_normalize_scaler

    @torch.no_grad()
    def merge_out(self: Module, merge_out_weight=1.0):
        # make sure it is positive

@@ -251,6 +311,23 @@ class ToolkitModuleMixin:
        org_sd["weight"] = weight.to(orig_dtype)
        self.org_module[0].load_state_dict(org_sd)

    def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
        # LoRM (Low Rank Middle) is a method to reduce the number of parameters in a module while keeping the inputs and
        # outputs the same. It is basically a LoRA but with the original module removed

        # if a state dict is passed, use those weights instead of extracting
        # todo load from state dict
        network: Network = self.network_ref()
        lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name)

        extract_mode = lorm_config.extract_mode
        extract_mode_param = lorm_config.extract_mode_param
        parameter_threshold = lorm_config.parameter_threshold
        self.extract_weight(
            extract_mode=extract_mode,
            extract_mode_param=extract_mode_param
        )


class ToolkitNetworkMixin:
    def __init__(

@@ -260,6 +337,8 @@ class ToolkitNetworkMixin:
            train_unet: Optional[bool] = True,
            is_sdxl=False,
            is_v2=False,
            network_config: Optional[NetworkConfig] = None,
            is_lorm=False,
            **kwargs
    ):
        self.train_text_encoder = train_text_encoder

@@ -267,11 +346,14 @@ class ToolkitNetworkMixin:
        self.is_checkpointing = False
        self._multiplier: float = 1.0
        self.is_active: bool = False
        self._is_normalizing: bool = False
        self.is_sdxl = is_sdxl
        self.is_v2 = is_v2
        self.is_merged_in = False
        # super().__init__(*args, **kwargs)
        self.is_lorm = is_lorm
        self.network_config: NetworkConfig = network_config
        self.module_losses: List[torch.Tensor] = []
        self.lorm_train_mode: Literal['local', None] = None
        self.can_merge_in = not is_lorm

    def get_keymap(self: Network):
        if self.is_sdxl:

@@ -443,28 +525,41 @@ class ToolkitNetworkMixin:
        self.is_checkpointing = False
        self._update_checkpointing()

    @property
    def is_normalizing(self: Network) -> bool:
        return self._is_normalizing

    @is_normalizing.setter
    def is_normalizing(self: Network, value: bool):
        self._is_normalizing = value
        # for module in self.get_all_modules():
        #     module.is_normalizing = self._is_normalizing

    def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0):
        for module in self.get_all_modules():
            module.apply_stored_normalizer(target_normalize_scaler)

    def merge_in(self, merge_weight=1.0):
        self.is_merged_in = True
        for module in self.get_all_modules():
            module.merge_in(merge_weight)

    def merge_out(self, merge_weight=1.0):
    def merge_out(self: Network, merge_weight=1.0):
        if not self.is_merged_in:
            return
        self.is_merged_in = False
        for module in self.get_all_modules():
            module.merge_out(merge_weight)

    def extract_weight(
            self: Network,
            extract_mode: ExtractMode = "existing",
            extract_mode_param: Union[int, float] = None,
    ):
        if extract_mode_param is None:
            raise ValueError("extract_mode_param must be set")
        for module in tqdm(self.get_all_modules(), desc="Extracting weights"):
            module.extract_weight(
                extract_mode=extract_mode,
                extract_mode_param=extract_mode_param
            )

    def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None):
        for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"):
            module.setup_lorm(state_dict=state_dict)

    def calculate_lorem_parameter_reduction(self):
        params_reduced = 0
        for module in self.get_all_modules():
            num_orig_module_params = count_parameters(module.org_module[0])
            num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up)
            params_reduced += (num_orig_module_params - num_lorem_params)

        return params_reduced
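To make the reduction concrete, here is the arithmetic for one hypothetical Linear(1280, 1280) with bias, replaced at rank 64, counted the same way as above:

orig_params = 1280 * 1280 + 1280                 # weight + bias = 1,639,680
lorm_params = 1280 * 64 + 64 * 1280 + 1280       # lora_down + lora_up + bias = 165,120
params_reduced = orig_params - lorm_params       # 1,474,560, roughly a 90% reduction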

@@ -61,12 +61,8 @@ class BlankNetwork:
    def __init__(self):
        self.multiplier = 1.0
        self.is_active = True
        self.is_normalizing = False
        self.is_merged_in = False

    def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0):
        pass

    def __enter__(self):
        self.is_active = True

@@ -180,11 +176,19 @@ class StableDiffusion:
                **load_args
            )
        else:
            pipe = pipln.from_single_file(
                model_path,
                device=self.device_torch,
                torch_dtype=self.torch_dtype,
            )
            try:
                pipe = pipln.from_single_file(
                    model_path,
                    device=self.device_torch,
                    torch_dtype=self.torch_dtype,
                )
            except Exception as e:
                print("Error loading model from single file. Trying to load from pretrained")
                pipe = pipln.from_pretrained(
                    model_path,
                    device=self.device_torch,
                    torch_dtype=self.torch_dtype,
                )
            flush()

            text_encoders = [pipe.text_encoder, pipe.text_encoder_2]

@@ -277,19 +281,13 @@ class StableDiffusion:
            # check if we have the same network weight for all samples. If we do, we can merge in
            # the network to drastically speed up inference
            unique_network_weights = set([x.network_multiplier for x in image_configs])
            if len(unique_network_weights) == 1:
            if len(unique_network_weights) == 1 and self.network.can_merge_in:
                can_merge_in = True
                merge_multiplier = unique_network_weights.pop()
                network.merge_in(merge_weight=merge_multiplier)
        else:
            network = BlankNetwork()

        was_network_normalizing = network.is_normalizing
        # apply the normalizer if it is normalizing before inference and disable it
        if network.is_normalizing:
            network.apply_stored_normalizer()
            network.is_normalizing = False

        self.save_device_state()
        self.set_device_state_preset('generate')

@@ -471,7 +469,6 @@ class StableDiffusion:
        if self.network is not None:
            self.network.train()
            self.network.multiplier = start_multiplier
            self.network.is_normalizing = was_network_normalizing

        if network.is_merged_in:
            network.merge_out(merge_multiplier)