From 569d7464d55b061dd365b5b13b40d3839858fbba Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 14 Sep 2023 08:31:54 -0600 Subject: [PATCH] implemented device placement preset system more places. Vastly improved speed on setting network multiplier and activating network. Fixed timing issues on progress bar --- extensions_built_in/sd_trainer/SDTrainer.py | 3 +- jobs/process/BaseSDTrainProcess.py | 66 ++++++++------------- toolkit/config_modules.py | 1 + toolkit/lora_special.py | 8 ++- toolkit/lycoris_special.py | 11 +++- toolkit/network_mixins.py | 55 +++++++++++------ toolkit/progress_bar.py | 5 +- toolkit/sd_device_states_presets.py | 59 ++++++++++++++++++ toolkit/stable_diffusion_model.py | 56 ++++++++++------- 9 files changed, 173 insertions(+), 91 deletions(-) create mode 100644 toolkit/sd_device_states_presets.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 2baf7f47..522e5279 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -29,6 +29,7 @@ class SDTrainer(BaseSDTrainProcess): else: # offload it. Already cached self.sd.vae.to('cpu') + flush() def hook_train_loop(self, batch): @@ -110,7 +111,5 @@ class SDTrainer(BaseSDTrainProcess): loss_dict = OrderedDict( {'loss': loss.item()} ) - # reset network multiplier - network.multiplier = 1.0 return loss_dict diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index c5e6d71e..872b7305 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -7,6 +7,8 @@ from typing import Union # from lycoris.config import PRESET from torch.utils.data import DataLoader +import torch +import torch.backends.cuda from toolkit.basic import value_map from toolkit.data_loader import get_dataloader_from_datasets @@ -21,6 +23,7 @@ from toolkit.progress_bar import ToolkitProgressBar from toolkit.sampler import get_sampler from toolkit.scheduler import get_lr_scheduler +from toolkit.sd_device_states_presets import get_train_sd_device_state_preset from toolkit.stable_diffusion_model import StableDiffusion from jobs.process import BaseTrainProcess @@ -28,7 +31,6 @@ from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safete from toolkit.train_tools import get_torch_dtype import gc -import torch from tqdm import tqdm from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ @@ -135,6 +137,16 @@ class BaseSDTrainProcess(BaseTrainProcess): self.network: Union[Network, None] = None self.embedding: Union[Embedding, None] = None + # get the device state preset based on what we are training + self.train_device_state_preset = get_train_sd_device_state_preset( + device=self.device_torch, + train_unet=self.train_config.train_unet, + train_text_encoder=self.train_config.train_text_encoder, + cached_latents=self.is_latents_cached, + train_lora=self.network_config is not None, + train_embedding=self.embed_config is not None, + ) + def sample(self, step=None, is_first=False): sample_folder = os.path.join(self.save_root, 'samples') gen_img_config_list = [] @@ -477,6 +489,10 @@ class BaseSDTrainProcess(BaseTrainProcess): # if it has it if hasattr(te, 'enable_xformers_memory_efficient_attention'): te.enable_xformers_memory_efficient_attention() + if self.train_config.sdp: + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) if self.train_config.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -513,7 +529,8 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.datasets is not None: self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd) if self.datasets_reg is not None: - self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, self.sd) + self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size, + self.sd) if self.network_config is not None: # TODO should we completely switch to LycorisSpecialNetwork? @@ -547,6 +564,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.network.force_to(self.device_torch, dtype=dtype) # give network to sd so it can use it self.sd.network = self.network + self.network._update_torch_multiplier() self.network.apply_to( text_encoder, @@ -621,32 +639,12 @@ class BaseSDTrainProcess(BaseTrainProcess): if not params: # set trainable params params = self.embedding.get_trainable_params() + flush() else: - # set them to train or not - if self.train_config.train_unet: - self.sd.unet.requires_grad_(True) - self.sd.unet.train() - else: - self.sd.unet.requires_grad_(False) - self.sd.unet.eval() - if self.train_config.train_text_encoder: - if isinstance(self.sd.text_encoder, list): - for te in self.sd.text_encoder: - te.requires_grad_(True) - te.train() - else: - self.sd.text_encoder.requires_grad_(True) - self.sd.text_encoder.train() - else: - if isinstance(self.sd.text_encoder, list): - for te in self.sd.text_encoder: - te.requires_grad_(False) - te.eval() - else: - self.sd.text_encoder.requires_grad_(False) - self.sd.text_encoder.eval() + # set the device state preset before getting params + self.sd.set_device_state(self.train_device_state_preset) params = self.get_params() @@ -729,25 +727,9 @@ class BaseSDTrainProcess(BaseTrainProcess): # zero any gradients optimizer.zero_grad() - self.lr_scheduler.step(self.step_num) - if self.embedding is not None or self.train_config.train_text_encoder: - if isinstance(self.sd.text_encoder, list): - for te in self.sd.text_encoder: - te.train() - else: - self.sd.text_encoder.train() - else: - if isinstance(self.sd.text_encoder, list): - for te in self.sd.text_encoder: - te.eval() - else: - self.sd.text_encoder.eval() - if self.train_config.train_unet or self.embedding: - self.sd.unet.train() - else: - self.sd.unet.eval() + self.sd.set_device_state(self.train_device_state_preset) flush() # self.step_num = 0 for step in range(self.step_num, self.train_config.steps): diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 466b2b97..f23a0eed 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -85,6 +85,7 @@ class TrainConfig: self.batch_size: int = kwargs.get('batch_size', 1) self.dtype: str = kwargs.get('dtype', 'fp32') self.xformers = kwargs.get('xformers', False) + self.sdp = kwargs.get('sdp', False) self.train_unet = kwargs.get('train_unet', True) self.train_text_encoder = kwargs.get('train_text_encoder', True) self.min_snr_gamma = kwargs.get('min_snr_gamma', None) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index c0c27141..aadaf541 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -4,7 +4,6 @@ import os import re import sys from typing import List, Optional, Dict, Type, Union - import torch from transformers import CLIPTextModel @@ -46,11 +45,12 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module): dropout=None, rank_dropout=None, module_dropout=None, + network: 'LoRASpecialNetwork' = None, parent=None, **kwargs ): """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__() + super().__init__(network=network) self.lora_name = lora_name self.scalar = torch.tensor(1.0) @@ -150,7 +150,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): 5. modules_dimとmodules_alphaを指定 (推論用) """ # call the parent of the parent we are replacing (LoRANetwork) init - super(LoRANetwork, self).__init__() + torch.nn.Module.__init__(self) self.lora_dim = lora_dim self.alpha = alpha @@ -163,6 +163,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self._multiplier: float = 1.0 self.is_active: bool = False self._is_normalizing: bool = False + self.torch_multiplier = None # triggers the state updates self.multiplier = multiplier self.is_sdxl = is_sdxl @@ -258,6 +259,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): dropout=dropout, rank_dropout=rank_dropout, module_dropout=module_dropout, + network=self, parent=module, ) loras.append(lora) diff --git a/toolkit/lycoris_special.py b/toolkit/lycoris_special.py index b881f9f2..41bdf719 100644 --- a/toolkit/lycoris_special.py +++ b/toolkit/lycoris_special.py @@ -29,6 +29,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule): lora_dim=4, alpha=1, dropout=0., rank_dropout=0., module_dropout=0., use_cp=False, + network: 'LycorisSpecialNetwork' = None, parent=None, **kwargs, ): @@ -36,7 +37,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule): # call super of super torch.nn.Module.__init__(self) # call super of - super().__init__(call_super_init=False) + super().__init__(call_super_init=False, network=network) self.lora_name = lora_name self.lora_dim = lora_dim self.cp = False @@ -170,6 +171,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): if module_dropout is None: module_dropout = 0 + self.torch_multiplier = None + # triggers a tensor update self.multiplier = multiplier self.lora_dim = lora_dim @@ -229,6 +232,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): self.lora_dim, self.alpha, self.dropout, self.rank_dropout, self.module_dropout, use_cp, + network=self, parent=module, **kwargs ) @@ -240,6 +244,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): self.lora_dim, self.alpha, self.dropout, self.rank_dropout, self.module_dropout, use_cp, + network=self, parent=module, **kwargs ) @@ -249,6 +254,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): self.conv_lora_dim, self.conv_alpha, self.dropout, self.rank_dropout, self.module_dropout, use_cp, + network=self, parent=module, **kwargs ) @@ -271,6 +277,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): self.dropout, self.rank_dropout, self.module_dropout, use_cp, parent=module, + network=self, **kwargs ) elif module.__class__.__name__ == 'Conv2d': @@ -281,6 +288,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): self.lora_dim, self.alpha, self.dropout, self.rank_dropout, self.module_dropout, use_cp, + network=self, parent=module, **kwargs ) @@ -290,6 +298,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork): self.conv_lora_dim, self.conv_alpha, self.dropout, self.rank_dropout, self.module_dropout, use_cp, + network=self, parent=module, **kwargs ) diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 8b5b8ca6..ebfd0b66 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -4,10 +4,8 @@ from collections import OrderedDict from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any import torch -from diffusers.utils import is_torch_version from torch import nn -from torch.utils.checkpoint import checkpoint - +import weakref from toolkit.metadata import add_model_hash_to_meta from toolkit.paths import KEYMAPS_ROOT @@ -47,11 +45,13 @@ class ToolkitModuleMixin: def __init__( self: Module, *args, + network: Network, call_super_init: bool = True, **kwargs ): if call_super_init: super().__init__(*args, **kwargs) + self.network_ref: weakref.ref = weakref.ref(network) self.is_checkpointing = False self.is_normalizing = False self.normalize_scaler = 1.0 @@ -125,13 +125,13 @@ class ToolkitModuleMixin: # this may get an additional positional arg or not def forward(self: Module, x, *args, **kwargs): - # diffusers added scale to resnet.. not sure what it does - if self._multiplier is None: - self.set_multiplier(0.0) + if not self.network_ref().is_active: + # network is not active, avoid doing anything + return self.org_forward(x, *args, **kwargs) org_forwarded = self.org_forward(x, *args, **kwargs) lora_output = self._call_forward(x) - multiplier = self._multiplier.clone().detach() + multiplier = self.network_ref().torch_multiplier lora_output_batch_size = lora_output.size(0) multiplier_batch_size = multiplier.size(0) @@ -328,35 +328,52 @@ class ToolkitNetworkMixin: extra_dict = None return extra_dict + def _update_torch_multiplier(self: Network): + # builds a tensor for fast usage in the forward pass of the network modules + # without having to set it in every single module every time it changes + multiplier = self._multiplier + # get first module + first_module = self.get_all_modules()[0] + device = first_module.lora_down.weight.device + dtype = first_module.lora_down.weight.dtype + with torch.no_grad(): + tensor_multiplier = None + if isinstance(multiplier, int) or isinstance(multiplier, float): + tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype) + elif isinstance(multiplier, list): + tensor_list = [] + for m in multiplier: + if isinstance(m, int) or isinstance(m, float): + tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype)) + elif isinstance(m, torch.Tensor): + tensor_list.append(m.clone().detach().to(device, dtype=dtype)) + tensor_multiplier = torch.cat(tensor_list) + elif isinstance(multiplier, torch.Tensor): + tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype) + + self.torch_multiplier = tensor_multiplier.clone().detach() + + @property def multiplier(self) -> Union[float, List[float]]: return self._multiplier @multiplier.setter def multiplier(self, value: Union[float, List[float]]): - # only update if the value has changed + # it takes time to update all the multipliers, so we only do it if the value has changed if self._multiplier == value: return + # if we are setting a single value but have a list, keep the list if every item is the same as value self._multiplier = value - self._update_lora_multiplier() - - def _update_lora_multiplier(self: Network): - if self.is_active: - for lora in self.get_all_modules(): - lora.set_multiplier(self._multiplier) - else: - for lora in self.get_all_modules(): - lora.set_multiplier(0) + self._update_torch_multiplier() # called when the context manager is entered # ie: with network: def __enter__(self: Network): self.is_active = True - self._update_lora_multiplier() def __exit__(self: Network, exc_type, exc_value, tb): self.is_active = False - self._update_lora_multiplier() def force_to(self: Network, device, dtype): self.to(device, dtype) diff --git a/toolkit/progress_bar.py b/toolkit/progress_bar.py index 2707c56e..e42f8086 100644 --- a/toolkit/progress_bar.py +++ b/toolkit/progress_bar.py @@ -6,6 +6,7 @@ class ToolkitProgressBar(tqdm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.paused = False + self.last_time = self._time() def pause(self): if not self.paused: @@ -15,7 +16,9 @@ class ToolkitProgressBar(tqdm): def unpause(self): if self.paused: self.paused = False - self.start_t += self._time() - self.last_time + cur_t = self._time() + self.start_t += cur_t - self.last_time + self.last_print_t = cur_t def update(self, *args, **kwargs): if not self.paused: diff --git a/toolkit/sd_device_states_presets.py b/toolkit/sd_device_states_presets.py new file mode 100644 index 00000000..9ffbd945 --- /dev/null +++ b/toolkit/sd_device_states_presets.py @@ -0,0 +1,59 @@ +import torch +import copy + +empty_preset = { + 'vae': { + 'training': False, + 'device': 'cpu', + }, + 'unet': { + 'training': False, + 'requires_grad': False, + 'device': 'cpu', + }, + 'text_encoder': { + 'training': False, + 'requires_grad': False, + 'device': 'cpu', + } +} + + +def get_train_sd_device_state_preset ( + device: torch.DeviceObjType, + train_unet: bool = False, + train_text_encoder: bool = False, + cached_latents: bool = False, + train_lora: bool = False, + train_embedding: bool = False, +): + preset = copy.deepcopy(empty_preset) + if not cached_latents: + preset['vae']['device'] = device + + if train_unet: + preset['unet']['training'] = True + preset['unet']['requires_grad'] = True + preset['unet']['device'] = device + else: + preset['unet']['device'] = device + + if train_text_encoder: + preset['text_encoder']['training'] = True + preset['text_encoder']['requires_grad'] = True + preset['text_encoder']['device'] = device + else: + preset['text_encoder']['device'] = device + + if train_embedding: + preset['text_encoder']['training'] = True + preset['text_encoder']['requires_grad'] = True + preset['text_encoder']['training'] = True + preset['unet']['training'] = True + + + if train_lora: + preset['text_encoder']['requires_grad'] = False + preset['unet']['requires_grad'] = False + + return preset diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index bd303eee..77704e48 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -24,7 +24,6 @@ from toolkit.sampler import get_sampler from toolkit.saving import save_ldm_model_from_diffusers from toolkit.train_tools import get_torch_dtype, apply_noise_offset import torch -from diffusers.schedulers import DDPMScheduler from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \ StableDiffusionKDiffusionXLPipeline from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline @@ -48,7 +47,7 @@ DO_NOT_TRAIN_WEIGHTS = [ "unet_time_embedding.linear_2.weight", ] -DeviceStatePreset = Literal['cache_latents'] +DeviceStatePreset = Literal['cache_latents', 'generate'] class BlankNetwork: @@ -111,7 +110,7 @@ class StableDiffusion: self.unet: Union[None, 'UNet2DConditionModel'] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] - self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers', 'DDPMScheduler'] = noise_scheduler + self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers'] = noise_scheduler # sdxl stuff self.logit_scale = None @@ -247,7 +246,7 @@ class StableDiffusion: # pipe.unet = prepare_unet_for_training(pipe.unet) self.unet = pipe.unet - self.vae = pipe.vae.to(self.device_torch, dtype=dtype) + self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype) self.vae.eval() self.vae.requires_grad_(False) self.unet.to(self.device_torch, dtype=dtype) @@ -275,26 +274,12 @@ class StableDiffusion: network.is_normalizing = False self.save_device_state() + self.set_device_state_preset('generate') # save current seed state for training rng_state = torch.get_rng_state() cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None - # handle sdxl text encoder - if isinstance(self.text_encoder, list): - for encoder, i in zip(self.text_encoder, range(len(self.text_encoder))): - encoder.to(self.device_torch) - encoder.eval() - else: - self.text_encoder.to(self.device_torch) - self.text_encoder.eval() - - self.vae.to(self.device_torch) - self.vae.eval() - self.unet.to(self.device_torch) - self.unet.eval() - flush() - noise_scheduler = self.noise_scheduler if sampler is not None: if sampler.startswith("sample_"): # sample_dpmpp_2m @@ -346,6 +331,7 @@ class StableDiffusion: start_multiplier = self.network.multiplier pipeline.to(self.device_torch) + with network: with torch.no_grad(): if self.network is not None: @@ -876,6 +862,7 @@ class StableDiffusion: 'unet': { 'training': self.unet.training, 'device': self.unet.device, + 'requires_grad': self.unet.conv_in.weight.requires_grad, }, } if isinstance(self.text_encoder, list): @@ -884,11 +871,14 @@ class StableDiffusion: self.device_state['text_encoder'].append({ 'training': encoder.training, 'device': encoder.device, + # todo there has to be a better way to do this + 'requires_grad': encoder.text_model.final_layer_norm.weight.requires_grad }) else: self.device_state['text_encoder'] = { 'training': self.text_encoder.training, 'device': self.text_encoder.device, + 'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad } def restore_device_state(self): @@ -910,19 +900,33 @@ class StableDiffusion: else: self.unet.eval() self.unet.to(state['unet']['device']) + if state['unet']['requires_grad']: + self.unet.requires_grad_(True) + else: + self.unet.requires_grad_(False) if isinstance(self.text_encoder, list): for i, encoder in enumerate(self.text_encoder): - if state['text_encoder'][i]['training']: - encoder.train() + if isinstance(state['text_encoder'], list): + if state['text_encoder'][i]['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder'][i]['device']) + encoder.requires_grad_(state['text_encoder'][i]['requires_grad']) else: - encoder.eval() - encoder.to(state['text_encoder'][i]['device']) + if state['text_encoder']['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder']['device']) + encoder.requires_grad_(state['text_encoder']['requires_grad']) else: if state['text_encoder']['training']: self.text_encoder.train() else: self.text_encoder.eval() self.text_encoder.to(state['text_encoder']['device']) + self.text_encoder.requires_grad_(state['text_encoder']['requires_grad']) flush() def set_device_state_preset(self, device_state_preset: DeviceStatePreset): @@ -935,18 +939,22 @@ class StableDiffusion: training_modules = [] if device_state_preset in ['cache_latents']: active_modules = ['vae'] + if device_state_preset in ['generate']: + active_modules = ['vae', 'unet', 'text_encoder'] state = {} # vae state['vae'] = { 'training': 'vae' in training_modules, 'device': self.device_torch if 'vae' in active_modules else 'cpu', + 'requires_grad': 'vae' in training_modules, } # unet state['unet'] = { 'training': 'unet' in training_modules, 'device': self.device_torch if 'unet' in active_modules else 'cpu', + 'requires_grad': 'unet' in training_modules, } # text encoder @@ -956,11 +964,13 @@ class StableDiffusion: state['text_encoder'].append({ 'training': 'text_encoder' in training_modules, 'device': self.device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, }) else: state['text_encoder'] = { 'training': 'text_encoder' in training_modules, 'device': self.device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, } self.set_device_state(state)