From 6f3e0d5af2f97308322ee07cfa3982dcb14e9247 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 28 Oct 2023 08:21:59 -0600 Subject: [PATCH] Improved lorm extraction and training --- .../advanced_generator/PureLoraGenerator.py | 102 +++++++ .../advanced_generator/__init__.py | 19 +- extensions_built_in/sd_trainer/SDTrainer.py | 22 +- jobs/process/BaseSDTrainProcess.py | 73 +++-- toolkit/config_modules.py | 70 ++++- toolkit/lora_special.py | 57 +++- toolkit/lorm.py | 70 +++-- toolkit/lycoris_special.py | 32 +- toolkit/network_mixins.py | 279 ++++++++++++------ toolkit/stable_diffusion_model.py | 31 +- 10 files changed, 559 insertions(+), 196 deletions(-) create mode 100644 extensions_built_in/advanced_generator/PureLoraGenerator.py diff --git a/extensions_built_in/advanced_generator/PureLoraGenerator.py b/extensions_built_in/advanced_generator/PureLoraGenerator.py new file mode 100644 index 00000000..ec19da31 --- /dev/null +++ b/extensions_built_in/advanced_generator/PureLoraGenerator.py @@ -0,0 +1,102 @@ +import os +from collections import OrderedDict + +from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig +from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm +from toolkit.sd_device_states_presets import get_train_sd_device_state_preset +from toolkit.stable_diffusion_model import StableDiffusion +import gc +import torch +from jobs.process import BaseExtensionProcess +from toolkit.train_tools import get_torch_dtype + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +class PureLoraGenerator(BaseExtensionProcess): + + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.output_folder = self.get_conf('output_folder', required=True) + self.device = self.get_conf('device', 'cuda') + self.device_torch = torch.device(self.device) + self.model_config = ModelConfig(**self.get_conf('model', required=True)) + self.generate_config = SampleConfig(**self.get_conf('sample', required=True)) + self.dtype = self.get_conf('dtype', 'float16') + self.torch_dtype = get_torch_dtype(self.dtype) + lorm_config = self.get_conf('lorm', None) + self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None + + self.device_state_preset = get_train_sd_device_state_preset( + device=torch.device(self.device), + ) + + self.progress_bar = None + self.sd = StableDiffusion( + device=self.device, + model_config=self.model_config, + dtype=self.dtype, + ) + + def run(self): + super().run() + print("Loading model...") + with torch.no_grad(): + self.sd.load_model() + self.sd.unet.eval() + self.sd.unet.to(self.device_torch) + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + te.to(self.device_torch) + else: + self.sd.text_encoder.eval() + self.sd.to(self.device_torch) + + print(f"Converting to LoRM UNet") + # replace the unet with LoRMUnet + convert_diffusers_unet_to_lorm( + self.sd.unet, + config=self.lorm_config, + ) + + sample_folder = os.path.join(self.output_folder) + gen_img_config_list = [] + + sample_config = self.generate_config + start_seed = sample_config.seed + current_seed = start_seed + for i in range(len(sample_config.prompts)): + if sample_config.walk_seed: + current_seed = start_seed + i + + filename = f"[time]_[count].{self.generate_config.ext}" + output_path = os.path.join(sample_folder, filename) + prompt = sample_config.prompts[i] + extra_args = {} + gen_img_config_list.append(GenerateImageConfig( + prompt=prompt, # it 
will autoparse the prompt + width=sample_config.width, + height=sample_config.height, + negative_prompt=sample_config.neg, + seed=current_seed, + guidance_scale=sample_config.guidance_scale, + guidance_rescale=sample_config.guidance_rescale, + num_inference_steps=sample_config.sample_steps, + network_multiplier=sample_config.network_multiplier, + output_path=output_path, + output_ext=sample_config.ext, + adapter_conditioning_scale=sample_config.adapter_conditioning_scale, + **extra_args + )) + + # send to be generated + self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler) + print("Done generating images") + # cleanup + del self.sd + gc.collect() + torch.cuda.empty_cache() diff --git a/extensions_built_in/advanced_generator/__init__.py b/extensions_built_in/advanced_generator/__init__.py index d811fe89..94a91c6b 100644 --- a/extensions_built_in/advanced_generator/__init__.py +++ b/extensions_built_in/advanced_generator/__init__.py @@ -19,7 +19,24 @@ class AdvancedReferenceGeneratorExtension(Extension): return ReferenceGenerator +# This is for generic training (LoRA, Dreambooth, FineTuning) +class PureLoraGenerator(Extension): + # uid must be unique, it is how the extension is identified + uid = "pure_lora_generator" + + # name is the name of the extension for printing + name = "Pure LoRA Generator" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .PureLoraGenerator import PureLoraGenerator + return PureLoraGenerator + + AI_TOOLKIT_EXTENSIONS = [ # you can put a list of extensions here - AdvancedReferenceGeneratorExtension, + AdvancedReferenceGeneratorExtension, PureLoraGenerator ] diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index c85e1d87..269e5892 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -32,7 +32,6 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.inverted_mask_prior: self.do_prior_prediction = True - def before_model_load(self): pass @@ -193,6 +192,15 @@ class SDTrainer(BaseSDTrainProcess): self.network.is_active = was_network_active return prior_pred + def before_unet_predict(self): + pass + + def after_unet_predict(self): + pass + + def end_of_training_loop(self): + pass + def hook_train_loop(self, batch: 'DataLoaderBatchDTO'): self.timer.start('preprocess_batch') @@ -331,7 +339,6 @@ class SDTrainer(BaseSDTrainProcess): adapter_images_list = [adapter_images] mask_multiplier_list = [mask_multiplier] - for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier in zip( noisy_latents_list, noise_list, @@ -366,7 +373,8 @@ class SDTrainer(BaseSDTrainProcess): # flush() pred_kwargs = {} - if has_adapter_img and ((self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): + if has_adapter_img and ( + (self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): with torch.set_grad_enabled(self.adapter is not None): adapter = self.adapter if self.adapter else self.assistant_adapter adapter_multiplier = get_adapter_multiplier() @@ -406,8 +414,7 @@ class SDTrainer(BaseSDTrainProcess): conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images) conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) - - 
+ self.before_unet_predict() with self.timer('predict_unet'): noise_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), @@ -416,6 +423,7 @@ class SDTrainer(BaseSDTrainProcess): guidance_scale=1.0, **pred_kwargs ) + self.after_unet_predict() with self.timer('calculate_loss'): noise = noise.to(self.device_torch, dtype=dtype).detach() @@ -442,7 +450,7 @@ class SDTrainer(BaseSDTrainProcess): loss.backward() torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) - # flush() + # flush() with self.timer('optimizer_step'): # apply gradients @@ -460,4 +468,6 @@ class SDTrainer(BaseSDTrainProcess): {'loss': loss.item()} ) + self.end_of_training_loop() + return loss_dict diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index f2a956f6..ce8f5b00 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -7,6 +7,7 @@ from typing import Union, List import numpy as np from diffusers import T2IAdapter +from safetensors.torch import save_file, load_file # from lycoris.config import PRESET from torch.utils.data import DataLoader import torch @@ -18,7 +19,8 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatc from toolkit.embedding import Embedding from toolkit.ip_adapter import IPAdapter from toolkit.lora_special import LoRASpecialNetwork -from toolkit.lorm import convert_diffusers_unet_to_lorm +from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \ + lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE from toolkit.lycoris_special import LycorisSpecialNetwork from toolkit.network_mixins import Network from toolkit.optimizer import get_optimizer @@ -128,6 +130,9 @@ class BaseSDTrainProcess(BaseTrainProcess): is_training_adapter = self.adapter_config is not None and self.adapter_config.train self.do_lorm = self.get_conf('do_lorm', False) + self.lorm_extract_mode = self.get_conf('lorm_extract_mode', 'ratio') + self.lorm_extract_mode_param = self.get_conf('lorm_extract_mode_param', 0.25) + # 'ratio', 0.25) # get the device state preset based on what we are training self.train_device_state_preset = get_train_sd_device_state_preset( @@ -300,9 +305,6 @@ class BaseSDTrainProcess(BaseTrainProcess): file_path = os.path.join(self.save_root, filename) prev_multiplier = self.network.multiplier self.network.multiplier = 1.0 - if self.network_config.normalize: - # apply the normalization - self.network.apply_stored_normalizer() # if we are doing embedding training as well, add that embedding_dict = self.embedding.state_dict() if self.embedding else None @@ -427,6 +429,21 @@ class BaseSDTrainProcess(BaseTrainProcess): print("load_weights not implemented for non-network models") return None + def load_lorm(self): + latest_save_path = self.get_latest_save_path() + if latest_save_path is not None: + # hacky way to reload weights for now + # todo, do this + state_dict = load_file(latest_save_path, device=self.device) + self.sd.unet.load_state_dict(state_dict) + + meta = load_metadata_from_safetensors(latest_save_path) + # if 'training_info' in Orderdict keys + if 'training_info' in meta and 'step' in meta['training_info']: + self.step_num = meta['training_info']['step'] + self.start_step = self.step_num + print(f"Found step {self.step_num} in metadata, starting from there") + # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): # self.sd.noise_scheduler.set_timesteps(1000, 
device=self.device_torch) # sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype) @@ -610,7 +627,6 @@ class BaseSDTrainProcess(BaseTrainProcess): batch.mask_tensor = double_up_tensor(batch.mask_tensor) batch.control_tensor = double_up_tensor(batch.control_tensor) - # remove grads for these noisy_latents.requires_grad = False noisy_latents = noisy_latents.detach() @@ -712,15 +728,6 @@ class BaseSDTrainProcess(BaseTrainProcess): # run base sd process run self.sd.load_model() - if self.do_lorm: - train_modules = convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27) - for module in train_modules: - p = module.parameters() - for param in p: - param.requires_grad_(True) - params.append(param) - - dtype = get_torch_dtype(self.train_config.dtype) # model is loaded from BaseSDProcess @@ -783,14 +790,20 @@ class BaseSDTrainProcess(BaseTrainProcess): if not self.is_fine_tuning: if self.network_config is not None: # TODO should we completely switch to LycorisSpecialNetwork? - + network_kwargs = {} is_lycoris = False + is_lorm = self.network_config.type.lower() == 'lorm' # default to LoCON if there are any conv layers or if it is named NetworkClass = LoRASpecialNetwork if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris': NetworkClass = LycorisSpecialNetwork is_lycoris = True + if is_lorm: + network_kwargs['ignore_if_contains'] = lorm_ignore_if_contains + network_kwargs['parameter_threshold'] = lorm_parameter_threshold + network_kwargs['target_lin_modules'] = LORM_TARGET_REPLACE_MODULE + # if is_lycoris: # preset = PRESET['full'] # NetworkClass.apply_preset(preset) @@ -810,6 +823,10 @@ class BaseSDTrainProcess(BaseTrainProcess): dropout=self.network_config.dropout, use_text_encoder_1=self.model_config.use_text_encoder_1, use_text_encoder_2=self.model_config.use_text_encoder_2, + use_bias=is_lorm, + is_lorm=is_lorm, + network_config=self.network_config, + **network_kwargs ) self.network.force_to(self.device_torch, dtype=dtype) @@ -824,6 +841,20 @@ class BaseSDTrainProcess(BaseTrainProcess): self.train_config.train_unet ) + if is_lorm: + self.network.is_lorm = True + # make sure it is on the right device + self.sd.unet.to(self.sd.device, dtype=dtype) + original_unet_param_count = count_parameters(self.sd.unet) + self.network.setup_lorm() + new_unet_param_count = original_unet_param_count - self.network.calculate_lorem_parameter_reduction() + + print_lorm_extract_details( + start_num_params=original_unet_param_count, + end_num_params=new_unet_param_count, + num_replaced=len(self.network.get_all_modules()), + ) + self.network.prepare_grad_etc(text_encoder, unet) flush() @@ -846,9 +877,6 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.train_config.gradient_checkpointing: self.network.enable_gradient_checkpointing() - # set the network to normalize if we are - self.network.is_normalizing = self.network_config.normalize - lora_name = self.name # need to adapt name so they are not mixed up if self.named_lora: @@ -915,7 +943,6 @@ class BaseSDTrainProcess(BaseTrainProcess): # set the device state preset before getting params self.sd.set_device_state(self.train_device_state_preset) - # params = self.get_params() if len(params) == 0: # will only return savable weights and ones with grad @@ -1050,9 +1077,6 @@ class BaseSDTrainProcess(BaseTrainProcess): else: batch = None - # turn on normalization if we are using it and it is not on - if self.network is not None and self.network_config.normalize and not self.network.is_normalizing: - 
self.network.is_normalizing = True # flush() ### HOOK ### self.timer.start('train_loop') @@ -1078,11 +1102,6 @@ class BaseSDTrainProcess(BaseTrainProcess): self.progress_bar.set_postfix_str(prog_bar_string) - # apply network normalizer if we are using it, not on regularization steps - if self.network is not None and self.network.is_normalizing and not is_reg_step: - with self.timer('apply_normalizer'): - self.network.apply_stored_normalizer() - # if the batch is a DataLoaderBatchDTO, then we need to clean it up if isinstance(batch, DataLoaderBatchDTO): with self.timer('batch_cleanup'): diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 60032160..b307c3e0 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -40,7 +40,48 @@ class SampleConfig: self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0) -NetworkType = Literal['lora', 'locon'] +class LormModuleSettingsConfig: + def __init__(self, **kwargs): + self.contains: str = kwargs.get('contains', '4nt$3') + self.extract_mode: str = kwargs.get('extract_mode', 'ratio') + # min num parameters to attach to + self.parameter_threshold: int = kwargs.get('parameter_threshold', 0) + self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25) + + +class LoRMConfig: + def __init__(self, **kwargs): + self.extract_mode: str = kwargs.get('extract_mode', 'ratio') + self.extract_mode_param: dict = kwargs.get('extract_mode_param', 0.25) + self.parameter_threshold: int = kwargs.get('parameter_threshold', 0) + module_settings = kwargs.get('module_settings', []) + default_module_settings = { + 'extract_mode': self.extract_mode, + 'extract_mode_param': self.extract_mode_param, + 'parameter_threshold': self.parameter_threshold, + } + module_settings = [{**default_module_settings, **module_setting, } for module_setting in module_settings] + self.module_settings: List[LormModuleSettingsConfig] = [LormModuleSettingsConfig(**module_setting) for + module_setting in module_settings] + + def get_config_for_module(self, block_name): + for setting in self.module_settings: + contain_pieces = setting.contains.split('|') + if all(contain_piece in block_name for contain_piece in contain_pieces): + return setting + # try replacing the . 
with _ + contain_pieces = setting.contains.replace('.', '_').split('|') + if all(contain_piece in block_name for contain_piece in contain_pieces): + return setting + # do default + return LormModuleSettingsConfig(**{ + 'extract_mode': self.extract_mode, + 'extract_mode_param': self.extract_mode_param, + 'parameter_threshold': self.parameter_threshold, + }) + + +NetworkType = Literal['lora', 'locon', 'lorm'] class NetworkConfig: @@ -58,12 +99,22 @@ class NetworkConfig: self.alpha: float = kwargs.get('alpha', 1.0) self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha) self.conv_alpha: float = kwargs.get('conv_alpha', self.conv) - self.normalize = kwargs.get('normalize', False) self.dropout: Union[float, None] = kwargs.get('dropout', None) + self.lorm_config: Union[LoRMConfig, None] = None + lorm = kwargs.get('lorm', None) + if lorm is not None: + self.lorm_config: LoRMConfig = LoRMConfig(**lorm) + + if self.type == 'lorm': + # set linear to arbitrary values so it makes them + self.linear = 4 + self.rank = 4 + AdapterTypes = Literal['t2i', 'ip', 'ip+'] + class AdapterConfig: def __init__(self, **kwargs): self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip @@ -90,6 +141,7 @@ class EmbeddingConfig: ContentOrStyleType = Literal['balanced', 'style', 'content'] LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise'] + class TrainConfig: def __init__(self, **kwargs): self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') @@ -138,7 +190,8 @@ class TrainConfig: match_adapter_assist = kwargs.get('match_adapter_assist', False) self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0) - self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise + self.loss_target: LossTarget = kwargs.get('loss_target', + 'noise') # noise, source, unaugmented, differential_noise # When a mask is passed in a dataset, and this is true, # we will predict noise without a the LoRa network and use the prediction as a target for @@ -151,7 +204,6 @@ class TrainConfig: self.match_adapter_chance = 1.0 - class ModelConfig: def __init__(self, **kwargs): self.name_or_path: str = kwargs.get('name_or_path', None) @@ -216,7 +268,7 @@ class SliderConfig: self.prompt_file: str = kwargs.get('prompt_file', None) self.prompt_tensors: str = kwargs.get('prompt_tensors', None) self.batch_full_slide: bool = kwargs.get('batch_full_slide', True) - self.use_adapter: bool = kwargs.get('use_adapter', None) # depth + self.use_adapter: bool = kwargs.get('use_adapter', None) # depth self.adapter_img_dir = kwargs.get('adapter_img_dir', None) self.high_ram = kwargs.get('high_ram', False) @@ -267,9 +319,11 @@ class DatasetConfig: self.augments: List[str] = kwargs.get('augments', []) self.control_path: str = kwargs.get('control_path', None) # depth maps, etc self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask - self.mask_path: str = kwargs.get('mask_path', None) # focus mask (black and white. White has higher loss than black) + self.mask_path: str = kwargs.get('mask_path', + None) # focus mask (black and white. White has higher loss than black) self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 
0 - 1 - self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes + self.poi: Union[str, None] = kwargs.get('poi', + None) # if one is set and in json data, will be used as auto crop scale point of interes self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset # cache latents will store them in memory self.cache_latents: bool = kwargs.get('cache_latents', False) @@ -525,4 +579,4 @@ class GenerateImageConfig: unconditional_prompt_embeds: Optional[PromptEmbeds] = None, ): # this is called after prompt embeds are encoded. We can override them in the future here - pass \ No newline at end of file + pass diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 21d5cc5c..8bf1e063 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -7,7 +7,9 @@ from typing import List, Optional, Dict, Type, Union import torch from transformers import CLIPTextModel -from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin +from .config_modules import NetworkConfig +from .lorm import count_parameters +from .network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin from .paths import SD_SCRIPTS_ROOT sys.path.append(SD_SCRIPTS_ROOT) @@ -30,7 +32,7 @@ CONV_MODULES = [ 'LoRACompatibleConv' ] -class LoRAModule(ToolkitModuleMixin, torch.nn.Module): +class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. """ @@ -46,13 +48,17 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): rank_dropout=None, module_dropout=None, network: 'LoRASpecialNetwork' = None, - parent=None, + use_bias: bool = False, **kwargs ): """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__(network=network) + ToolkitModuleMixin.__init__(self, network=network) + torch.nn.Module.__init__(self) self.lora_name = lora_name self.scalar = torch.tensor(1.0) + # check if parent has bias. 
if not force use_bias to False + if org_module.bias is None: + use_bias = False if org_module.__class__.__name__ in CONV_MODULES: in_dim = org_module.in_channels @@ -73,10 +79,10 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): stride = org_module.stride padding = org_module.padding self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias) else: self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() # without casting, bf16 causes error @@ -95,8 +101,6 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.is_checkpointing = False - self.is_normalizing = False - self.normalize_scaler = 1.0 def apply_to(self): self.org_forward = self.org_module[0].forward @@ -143,6 +147,13 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): train_unet: Optional[bool] = True, is_sdxl=False, is_v2=False, + use_bias: bool = False, + is_lorm: bool = False, + ignore_if_contains = None, + parameter_threshold: float = 0.0, + target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE, + target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3, + **kwargs ) -> None: """ LoRA network: すごく引数が多いが、パターンは以下の通り @@ -154,7 +165,18 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): """ # call the parent of the parent we are replacing (LoRANetwork) init torch.nn.Module.__init__(self) - + ToolkitNetworkMixin.__init__( + self, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + is_sdxl=is_sdxl, + is_v2=is_v2, + is_lorm=is_lorm, + **kwargs + ) + if ignore_if_contains is None: + ignore_if_contains = [] + self.ignore_if_contains = ignore_if_contains self.lora_dim = lora_dim self.alpha = alpha self.conv_lora_dim = conv_lora_dim @@ -165,13 +187,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.is_checkpointing = False self._multiplier: float = 1.0 self.is_active: bool = False - self._is_normalizing: bool = False self.torch_multiplier = None # triggers the state updates self.multiplier = multiplier self.is_sdxl = is_sdxl self.is_v2 = is_v2 - self.is_merged_in = False if modules_dim is not None: print(f"create LoRA network from weights") @@ -217,7 +237,15 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): is_conv2d = child_module.__class__.__name__ in CONV_MODULES is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) - if is_linear or is_conv2d: + skip = False + if any([word in child_name for word in self.ignore_if_contains]): + skip = True + + # see if it is over threshold + if count_parameters(child_module) < parameter_threshold: + skip = True + + if (is_linear or is_conv2d) and not skip: lora_name = prefix + "." + name + "." 
+ child_name lora_name = lora_name.replace(".", "_") @@ -265,6 +293,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): module_dropout=module_dropout, network=self, parent=module, + use_bias=use_bias, ) loras.append(lora) return loras, skipped @@ -295,9 +324,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights - target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + target_modules = target_lin_modules if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None: - target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + target_modules += target_conv_modules if train_unet: self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) diff --git a/toolkit/lorm.py b/toolkit/lorm.py index b6acb6da..782f1a86 100644 --- a/toolkit/lorm.py +++ b/toolkit/lorm.py @@ -6,6 +6,8 @@ from diffusers import UNet2DConditionModel from torch import Tensor from tqdm import tqdm +from toolkit.config_modules import LoRMConfig + conv = nn.Conv2d lin = nn.Linear _size_2_t = Union[int, Tuple[int, int]] @@ -29,12 +31,13 @@ CONV_MODULES = [ UNET_TARGET_REPLACE_MODULE = [ "Transformer2DModel", - # "BasicTransformerBlock", # "ResnetBlock2D", "Downsample2D", "Upsample2D", ] +LORM_TARGET_REPLACE_MODULE = UNET_TARGET_REPLACE_MODULE + UNET_TARGET_REPLACE_NAME = [ "conv_in", "conv_out", @@ -279,13 +282,38 @@ def compute_optimal_bias(original_module, linear_down, linear_up, X): return optimal_bias +def format_with_commas(n): + return f"{n:,}" + + +def print_lorm_extract_details( + start_num_params: int, + end_num_params: int, + num_replaced: int, +): + start_formatted = format_with_commas(start_num_params) + end_formatted = format_with_commas(end_num_params) + num_replaced_formatted = format_with_commas(num_replaced) + + width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted)) + + print(f"Convert UNet result:") + print(f" - converted: {num_replaced:>{width},} modules") + print(f" - start: {start_num_params:>{width},} params") + print(f" - end: {end_num_params:>{width},} params") + + +lorm_ignore_if_contains = [ + 'proj_out', 'proj_in', +] + +lorm_parameter_threshold = 1000000 + + @torch.no_grad() def convert_diffusers_unet_to_lorm( unet: UNet2DConditionModel, - extract_mode: ExtractMode = "percentile", - mode_param: Union[int, float] = 0.5, - parameter_threshold: int = 500000, - # parameter_threshold: int = 1500000 + config: LoRMConfig, ): print('Converting UNet to LoRM UNet') start_num_params = count_parameters(unet) @@ -299,8 +327,6 @@ def convert_diffusers_unet_to_lorm( ignore_if_contains = [ 'proj_out', 'proj_in', ] - def format_with_commas(n): - return f"{n:,}" for name, module in named_modules: module_name = module.__class__.__name__ @@ -311,6 +337,13 @@ def convert_diffusers_unet_to_lorm( combined_name = combined_name = f"{name}.{child_name}" # if child_module.__class__.__name__ in LINEAR_MODULES and child_module.bias is None: # pass + + lorm_config = config.get_config_for_module(combined_name) + + extract_mode = lorm_config.extract_mode + extract_mode_param = lorm_config.extract_mode_param + parameter_threshold = lorm_config.parameter_threshold + if any([word in child_name for word in ignore_if_contains]): pass @@ -322,7 +355,7 @@ def convert_diffusers_unet_to_lorm( down_weight, up_weight, lora_dim, diff = extract_linear( 
weight=child_module.weight.clone().detach().float(), mode=extract_mode, - mode_param=mode_param, + mode_param=extract_mode_param, device=child_module.weight.device, ) down_weight = down_weight.to(dtype=dtype) @@ -362,7 +395,7 @@ def convert_diffusers_unet_to_lorm( down_weight, up_weight, lora_dim, diff = extract_conv( weight=child_module.weight.clone().detach().float(), mode=extract_mode, - mode_param=mode_param, + mode_param=extract_mode_param, device=child_module.weight.device, ) down_weight = down_weight.to(dtype=dtype) @@ -395,30 +428,25 @@ def convert_diffusers_unet_to_lorm( replace_module_by_path(unet, combined_name, new_module) converted_modules.append(new_module) num_replaced += 1 - layer_names_replaced.append(f"{combined_name} - {format_with_commas(count_parameters(child_module))}") + layer_names_replaced.append( + f"{combined_name} - {format_with_commas(count_parameters(child_module))}") pbar.update(1) pbar.close() end_num_params = count_parameters(unet) - start_formatted = format_with_commas(start_num_params) - end_formatted = format_with_commas(end_num_params) - num_replaced_formatted = format_with_commas(num_replaced) - - width = max(len(start_formatted), len(end_formatted), len(num_replaced_formatted)) - def sorting_key(s): # Extract the number part, remove commas, and convert to integer return int(s.split("-")[1].strip().replace(",", "")) sorted_layer_names_replaced = sorted(layer_names_replaced, key=sorting_key, reverse=True) - for layer_name in sorted_layer_names_replaced: print(layer_name) - print(f"Convert UNet result:") - print(f" - converted: {num_replaced:>{width},} modules") - print(f" - start: {start_num_params:>{width},} params") - print(f" - end: {end_num_params:>{width},} params") + print_lorm_extract_details( + start_num_params=start_num_params, + end_num_params=end_num_params, + num_replaced=num_replaced, + ) return converted_modules diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py index f38556e7..84021b49 100644 --- a/toolkit/lycoris_special.py +++ b/toolkit/lycoris_special.py @@ -8,20 +8,19 @@ from lycoris.modules.glora import GLoRAModule from torch import nn from transformers import CLIPTextModel from torch.nn import functional as F -from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin +from toolkit.network_mixins import ToolkitNetworkMixin, ToolkitModuleMixin, ExtractableModuleMixin # diffusers specific stuff LINEAR_MODULES = [ 'Linear', 'LoRACompatibleLinear' - # 'GroupNorm', ] CONV_MODULES = [ 'Conv2d', 'LoRACompatibleConv' ] -class LoConSpecialModule(ToolkitModuleMixin, LoConModule): +class LoConSpecialModule(ToolkitModuleMixin, LoConModule, ExtractableModuleMixin): def __init__( self, lora_name, org_module: nn.Module, @@ -30,18 +29,20 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule): dropout=0., rank_dropout=0., module_dropout=0., use_cp=False, network: 'LycorisSpecialNetwork' = None, - parent=None, + use_bias=False, **kwargs, ): """ if alpha == 0 or None, alpha is rank (no scaling). """ # call super of super + ToolkitModuleMixin.__init__(self, network=network) torch.nn.Module.__init__(self) - # call super of - super().__init__(call_super_init=False, network=network) self.lora_name = lora_name self.lora_dim = lora_dim self.cp = False + # check if parent has bias. 
if not force use_bias to False + if org_module.bias is None: + use_bias = False self.scalar = nn.Parameter(torch.tensor(0.0)) orig_module_name = org_module.__class__.__name__ @@ -61,7 +62,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule): self.cp = True else: self.lora_down = nn.Conv2d(in_dim, lora_dim, k_size, stride, padding, bias=False) - self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False) + self.lora_up = nn.Conv2d(lora_dim, out_dim, (1, 1), bias=use_bias) elif orig_module_name in LINEAR_MODULES: self.isconv = False self.down_op = F.linear @@ -74,7 +75,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule): in_dim = org_module.in_features out_dim = org_module.out_features self.lora_down = nn.Linear(in_dim, lora_dim, bias=False) - self.lora_up = nn.Linear(lora_dim, out_dim, bias=False) + self.lora_up = nn.Linear(lora_dim, out_dim, bias=use_bias) else: raise NotImplementedError self.shape = org_module.weight.shape @@ -159,10 +160,16 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): train_text_encoder: bool = True, use_text_encoder_1: bool = True, use_text_encoder_2: bool = True, + use_bias: bool = False, + is_lorm: bool = False, **kwargs, ) -> None: # call ToolkitNetworkMixin super - super().__init__( + ToolkitNetworkMixin.__init__( + self, + train_text_encoder=train_text_encoder, + train_unet=train_unet, + is_lorm=is_lorm, **kwargs ) # call the parent of the parent LycorisNetwork @@ -217,7 +224,6 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): loras = [] # remove this named_modules = root_module.named_modules() - modules = root_module.modules() # add a few to tthe generator for name, module in named_modules: @@ -241,6 +247,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): use_cp, network=self, parent=module, + use_bias=use_bias, **kwargs ) elif child_module.__class__.__name__ in CONV_MODULES: @@ -253,6 +260,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): use_cp, network=self, parent=module, + use_bias=use_bias, **kwargs ) elif conv_lora_dim > 0: @@ -263,6 +271,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): use_cp, network=self, parent=module, + use_bias=use_bias, **kwargs ) else: @@ -285,6 +294,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): use_cp, parent=module, network=self, + use_bias=use_bias, **kwargs ) elif module.__class__.__name__ == 'Conv2d': @@ -297,6 +307,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): use_cp, network=self, parent=module, + use_bias=use_bias, **kwargs ) elif conv_lora_dim > 0: @@ -307,6 +318,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): use_cp, network=self, parent=module, + use_bias=use_bias, **kwargs ) else: diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index c62de64b..172925d2 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -1,17 +1,23 @@ import json import os from collections import OrderedDict -from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any +from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any, Literal import torch from torch import nn import weakref + +from tqdm import tqdm + +from toolkit.config_modules import NetworkConfig +from toolkit.lorm import extract_conv, extract_linear, count_parameters from toolkit.metadata import add_model_hash_to_meta from toolkit.paths import KEYMAPS_ROOT if TYPE_CHECKING: from toolkit.lycoris_special import 
LycorisSpecialNetwork, LoConSpecialModule from toolkit.lora_special import LoRASpecialNetwork, LoRAModule + from toolkit.stable_diffusion_model import StableDiffusion Network = Union['LycorisSpecialNetwork', 'LoRASpecialNetwork'] Module = Union['LoConSpecialModule', 'LoRAModule'] @@ -26,6 +32,15 @@ CONV_MODULES = [ 'LoRACompatibleConv' ] +ExtractMode = Union[ + 'existing' + 'fixed', + 'threshold', + 'ratio', + 'quantile', + 'percentage' +] + def broadcast_and_multiply(tensor, multiplier): # Determine the number of dimensions required @@ -41,20 +56,101 @@ def broadcast_and_multiply(tensor, multiplier): return result +def add_bias(tensor, bias): + if bias is None: + return tensor + # add batch dim + bias = bias.unsqueeze(0) + bias = torch.cat([bias] * tensor.size(0), dim=0) + # Determine the number of dimensions required + num_extra_dims = tensor.dim() - bias.dim() + + # Unsqueezing the tensor to match the dimensionality + for _ in range(num_extra_dims): + bias = bias.unsqueeze(-1) + + # we may need to swap -1 for -2 + if bias.size(1) != tensor.size(1): + if len(bias.size()) == 3: + bias = bias.permute(0, 2, 1) + elif len(bias.size()) == 4: + bias = bias.permute(0, 3, 1, 2) + + # Multiplying the broadcasted tensor with the output tensor + try: + result = tensor + bias + except RuntimeError as e: + print(e) + print(tensor.size()) + print(bias.size()) + raise e + + return result + + +class ExtractableModuleMixin: + def extract_weight( + self: Module, + extract_mode: ExtractMode = "existing", + extract_mode_param: Union[int, float] = None, + ): + device = self.lora_down.weight.device + weight_to_extract = self.org_module[0].weight + if extract_mode == "existing": + extract_mode = 'fixed' + extract_mode_param = self.lora_dim + + if self.org_module[0].__class__.__name__ in CONV_MODULES: + # do conv extraction + down_weight, up_weight, new_dim, diff = extract_conv( + weight=weight_to_extract.clone().detach().float(), + mode=extract_mode, + mode_param=extract_mode_param, + device=device + ) + + elif self.org_module[0].__class__.__name__ in LINEAR_MODULES: + # do linear extraction + down_weight, up_weight, new_dim, diff = extract_linear( + weight=weight_to_extract.clone().detach().float(), + mode=extract_mode, + mode_param=extract_mode_param, + device=device, + ) + else: + raise ValueError(f"Unknown module type: {self.org_module[0].__class__.__name__}") + + self.lora_dim = new_dim + + # inject weights into the param + self.lora_down.weight.data = down_weight.to(self.lora_down.weight.dtype).clone().detach() + self.lora_up.weight.data = up_weight.to(self.lora_up.weight.dtype).clone().detach() + + # copy bias if we have one and are using them + if self.org_module[0].bias is not None and self.lora_up.bias is not None: + self.lora_up.bias.data = self.org_module[0].bias.data.clone().detach() + + # set up alphas + self.alpha = (self.alpha * 0) + down_weight.shape[0] + self.scale = self.alpha / self.lora_dim + + # assign them + + # handle trainable scaler method locon does + if hasattr(self, 'scalar'): + # scaler is a parameter update the value with 1.0 + self.scalar.data = torch.tensor(1.0).to(self.scalar.device, self.scalar.dtype) + + class ToolkitModuleMixin: def __init__( self: Module, *args, network: Network, - call_super_init: bool = True, **kwargs ): - if call_super_init: - super().__init__(*args, **kwargs) self.network_ref: weakref.ref = weakref.ref(network) self.is_checkpointing = False - # self.is_normalizing = False - self.normalize_scaler = 1.0 self._multiplier: Union[float, list, torch.Tensor] 
= None def _call_forward(self: Module, x): @@ -100,11 +196,40 @@ class ToolkitModuleMixin: return lx * scale - # this may get an additional positional arg or not + + def lorm_forward(self: Network, x, *args, **kwargs): + network: Network = self.network_ref() + if not network.is_active: + return self.org_forward(x, *args, **kwargs) + + if network.lorm_train_mode == 'local': + # we are going to predict input with both and do a loss on them + inputs = x.detach() + with torch.no_grad(): + # get the local prediction + target_pred = self.org_forward(inputs, *args, **kwargs).detach() + with torch.set_grad_enabled(True): + # make a prediction with the lorm + lorm_pred = self.lora_up(self.lora_down(inputs.requires_grad_(True))) + + local_loss = torch.nn.functional.mse_loss(target_pred.float(), lorm_pred.float()) + # backpropr + local_loss.backward() + + network.module_losses.append(local_loss.detach()) + # return the original as we dont want our trainer to affect ones down the line + return target_pred + + else: + return self.lora_up(self.lora_down(x)) def forward(self: Module, x, *args, **kwargs): skip = False - network = self.network_ref() + network: Network = self.network_ref() + if network.is_lorm: + # we are doing lorm + return self.lorm_forward(x, *args, **kwargs) + # skip if not active if not network.is_active: skip = True @@ -130,40 +255,9 @@ class ToolkitModuleMixin: if lora_output_batch_size != multiplier_batch_size: num_interleaves = lora_output_batch_size // multiplier_batch_size multiplier = multiplier.repeat_interleave(num_interleaves) - # multiplier = 1.0 - if self.network_ref().is_normalizing: - with torch.no_grad(): - - # do this calculation without set multiplier and instead use same polarity, but with 1.0 multiplier - if isinstance(multiplier, torch.Tensor): - norm_multiplier = multiplier.clone().detach() * 10 - norm_multiplier = norm_multiplier.clamp(min=-1.0, max=1.0) - else: - norm_multiplier = multiplier - - # get a dim array from orig forward that had index of all dimensions except the batch and channel - - # Calculate the target magnitude for the combined output - orig_max = torch.max(torch.abs(org_forwarded)) - - # Calculate the additional increase in magnitude that lora_output would introduce - potential_max_increase = torch.max( - torch.abs(org_forwarded + lora_output * norm_multiplier) - torch.abs(org_forwarded)) - - epsilon = 1e-6 # Small constant to avoid division by zero - - # Calculate the scaling factor for the lora_output - # to ensure that the potential increase in magnitude doesn't change the original max - normalize_scaler = orig_max / (orig_max + potential_max_increase + epsilon) - normalize_scaler = normalize_scaler.detach() - - # save the scaler so it can be applied later - self.normalize_scaler = normalize_scaler.clone().detach() - - lora_output = lora_output * normalize_scaler - - return org_forwarded + broadcast_and_multiply(lora_output, multiplier) + x = org_forwarded + broadcast_and_multiply(lora_output, multiplier) + return x def enable_gradient_checkpointing(self: Module): self.is_checkpointing = True @@ -171,40 +265,6 @@ class ToolkitModuleMixin: def disable_gradient_checkpointing(self: Module): self.is_checkpointing = False - @torch.no_grad() - def apply_stored_normalizer(self: Module, target_normalize_scaler: float = 1.0): - """ - Applied the previous normalization calculation to the module. - This must be called before saving or normalization will be lost. - It is probably best to call after each batch as well. 
- We just scale the up down weights to match this vector - :return: - """ - # get state dict - state_dict = self.state_dict() - dtype = state_dict['lora_up.weight'].dtype - device = state_dict['lora_up.weight'].device - - # todo should we do this at fp32? - if isinstance(self.normalize_scaler, torch.Tensor): - scaler = self.normalize_scaler.clone().detach() - else: - scaler = torch.tensor(self.normalize_scaler).to(device, dtype=dtype) - - total_module_scale = scaler / target_normalize_scaler - num_modules_layers = 2 # up and down - up_down_scale = torch.pow(total_module_scale, 1.0 / num_modules_layers) \ - .to(device, dtype=dtype) - - # apply the scaler to the up and down weights - for key in state_dict.keys(): - if key.endswith('.lora_up.weight') or key.endswith('.lora_down.weight'): - # do it inplace do params are updated - state_dict[key] *= up_down_scale - - # reset the normalization scaler - self.normalize_scaler = target_normalize_scaler - @torch.no_grad() def merge_out(self: Module, merge_out_weight=1.0): # make sure it is positive @@ -251,6 +311,23 @@ class ToolkitModuleMixin: org_sd["weight"] = weight.to(orig_dtype) self.org_module[0].load_state_dict(org_sd) + def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None): + # LoRM (Low Rank Middle) is a method reduce the number of parameters in a module while keeping the inputs and + # outputs the same. It is basically a LoRA but with the original module removed + + # if a state dict is passed, use those weights instead of extracting + # todo load from state dict + network: Network = self.network_ref() + lorm_config = network.network_config.lorm_config.get_config_for_module(self.lora_name) + + extract_mode = lorm_config.extract_mode + extract_mode_param = lorm_config.extract_mode_param + parameter_threshold = lorm_config.parameter_threshold + self.extract_weight( + extract_mode=extract_mode, + extract_mode_param=extract_mode_param + ) + class ToolkitNetworkMixin: def __init__( @@ -260,6 +337,8 @@ class ToolkitNetworkMixin: train_unet: Optional[bool] = True, is_sdxl=False, is_v2=False, + network_config: Optional[NetworkConfig] = None, + is_lorm=False, **kwargs ): self.train_text_encoder = train_text_encoder @@ -267,11 +346,14 @@ class ToolkitNetworkMixin: self.is_checkpointing = False self._multiplier: float = 1.0 self.is_active: bool = False - self._is_normalizing: bool = False self.is_sdxl = is_sdxl self.is_v2 = is_v2 self.is_merged_in = False - # super().__init__(*args, **kwargs) + self.is_lorm = is_lorm + self.network_config: NetworkConfig = network_config + self.module_losses: List[torch.Tensor] = [] + self.lorm_train_mode: Literal['local', None] = None + self.can_merge_in = not is_lorm def get_keymap(self: Network): if self.is_sdxl: @@ -443,28 +525,41 @@ class ToolkitNetworkMixin: self.is_checkpointing = False self._update_checkpointing() - @property - def is_normalizing(self: Network) -> bool: - return self._is_normalizing - - @is_normalizing.setter - def is_normalizing(self: Network, value: bool): - self._is_normalizing = value - # for module in self.get_all_modules(): - # module.is_normalizing = self._is_normalizing - - def apply_stored_normalizer(self: Network, target_normalize_scaler: float = 1.0): - for module in self.get_all_modules(): - module.apply_stored_normalizer(target_normalize_scaler) - def merge_in(self, merge_weight=1.0): self.is_merged_in = True for module in self.get_all_modules(): module.merge_in(merge_weight) - def merge_out(self, merge_weight=1.0): + def merge_out(self: Network, 
merge_weight=1.0): if not self.is_merged_in: return self.is_merged_in = False for module in self.get_all_modules(): module.merge_out(merge_weight) + + def extract_weight( + self: Network, + extract_mode: ExtractMode = "existing", + extract_mode_param: Union[int, float] = None, + ): + if extract_mode_param is None: + raise ValueError("extract_mode_param must be set") + for module in tqdm(self.get_all_modules(), desc="Extracting weights"): + module.extract_weight( + extract_mode=extract_mode, + extract_mode_param=extract_mode_param + ) + + def setup_lorm(self: Network, state_dict: Optional[Dict[str, Any]] = None): + for module in tqdm(self.get_all_modules(), desc="Extracting LoRM"): + module.setup_lorm(state_dict=state_dict) + + def calculate_lorem_parameter_reduction(self): + params_reduced = 0 + for module in self.get_all_modules(): + num_orig_module_params = count_parameters(module.org_module[0]) + num_lorem_params = count_parameters(module.lora_down) + count_parameters(module.lora_up) + params_reduced += (num_orig_module_params - num_lorem_params) + + return params_reduced + diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index ee719037..0a732f4e 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -61,12 +61,8 @@ class BlankNetwork: def __init__(self): self.multiplier = 1.0 self.is_active = True - self.is_normalizing = False self.is_merged_in = False - def apply_stored_normalizer(self, target_normalize_scaler: float = 1.0): - pass - def __enter__(self): self.is_active = True @@ -180,11 +176,19 @@ class StableDiffusion: **load_args ) else: - pipe = pipln.from_single_file( - model_path, - device=self.device_torch, - torch_dtype=self.torch_dtype, - ) + try: + pipe = pipln.from_single_file( + model_path, + device=self.device_torch, + torch_dtype=self.torch_dtype, + ) + except Exception as e: + print("Error loading model from single file. Trying to load from pretrained") + pipe = pipln.from_pretrained( + model_path, + device=self.device_torch, + torch_dtype=self.torch_dtype, + ) flush() text_encoders = [pipe.text_encoder, pipe.text_encoder_2] @@ -277,19 +281,13 @@ class StableDiffusion: # check if we have the same network weight for all samples. If we do, we can merge in th # the network to drastically speed up inference unique_network_weights = set([x.network_multiplier for x in image_configs]) - if len(unique_network_weights) == 1: + if len(unique_network_weights) == 1 and self.network.can_merge_in: can_merge_in = True merge_multiplier = unique_network_weights.pop() network.merge_in(merge_weight=merge_multiplier) else: network = BlankNetwork() - was_network_normalizing = network.is_normalizing - # apply the normalizer if it is normalizing before inference and disable it - if network.is_normalizing: - network.apply_stored_normalizer() - network.is_normalizing = False - self.save_device_state() self.set_device_state_preset('generate') @@ -471,7 +469,6 @@ class StableDiffusion: if self.network is not None: self.network.train() self.network.multiplier = start_multiplier - self.network.is_normalizing = was_network_normalizing if network.is_merged_in: network.merge_out(merge_multiplier)
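
Background sketch of the LoRM ("Low Rank Middle") conversion this patch performs: convert_diffusers_unet_to_lorm and ExtractableModuleMixin.extract_weight replace a module with a low-rank pair lora_up(lora_down(x)) extracted from the original weights, keeping the original bias, so the module's inputs and outputs stay approximately the same while its parameter count drops. The snippet below is a minimal illustration for a bare nn.Linear with a hard-coded rank, not the toolkit's implementation: the real code also handles conv layers, picks the rank per module from the LoRMConfig extract_mode settings ('fixed', 'ratio', 'threshold', 'quantile'), and may distribute the singular values between the two factors differently. The function name linear_to_lorm and the 1280x1280 example layer are hypothetical, chosen only for illustration.

    # Illustrative sketch only: approximates the LoRM conversion for a bare nn.Linear
    # with a fixed rank, using a truncated SVD of the original weight matrix.
    import torch
    import torch.nn as nn


    @torch.no_grad()
    def linear_to_lorm(layer: nn.Linear, rank: int) -> nn.Sequential:
        """Replace a Linear with lora_up(lora_down(x)), keeping the original bias."""
        weight = layer.weight.float()                 # (out_features, in_features)
        u, s, vh = torch.linalg.svd(weight, full_matrices=False)
        u = u[:, :rank] * s[:rank]                    # fold singular values into the "up" factor
        vh = vh[:rank, :]

        lora_down = nn.Linear(layer.in_features, rank, bias=False)
        lora_up = nn.Linear(rank, layer.out_features, bias=layer.bias is not None)
        lora_down.weight.data = vh.to(layer.weight.dtype)
        lora_up.weight.data = u.to(layer.weight.dtype)
        if layer.bias is not None:
            lora_up.bias.data = layer.bias.data.clone()

        # the original module is discarded; only the low-rank pair remains
        return nn.Sequential(lora_down, lora_up)


    # Example: a 1280x1280 projection (1,639,680 params with bias) at rank 128
    # becomes 128*(1280+1280) + 1280 = 328,960 params while approximating the same mapping.
    layer = nn.Linear(1280, 1280)
    lorm = linear_to_lorm(layer, rank=128)
    x = torch.randn(2, 1280)
    print((layer(x) - lorm(x)).abs().max())          # low-rank approximation error

This per-layer reduction is the quantity that print_lorm_extract_details and calculate_lorem_parameter_reduction report after the UNet is converted.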