Implemented the device placement preset system in more places. Vastly improved speed when setting the network multiplier and activating the network. Fixed timing issues in the progress bar.

Jaret Burkett
2023-09-14 08:31:54 -06:00
parent 4e945917df
commit 569d7464d5
9 changed files with 173 additions and 91 deletions

View File

@@ -29,6 +29,7 @@ class SDTrainer(BaseSDTrainProcess):
else:
# offload it. Already cached
self.sd.vae.to('cpu')
flush()
def hook_train_loop(self, batch):
@@ -110,7 +111,5 @@ class SDTrainer(BaseSDTrainProcess):
loss_dict = OrderedDict(
{'loss': loss.item()}
)
# reset network multiplier
network.multiplier = 1.0
return loss_dict
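
Note: the reset above matters because the trainer can scale the LoRA contribution per batch. A minimal sketch of the intended pattern, assuming a network with the multiplier property and context-manager behavior added later in this commit (train_step is a hypothetical helper):

with network:                 # __enter__ sets is_active = True
    network.multiplier = 0.5  # hypothetical per-batch strength
    loss = train_step(batch)  # hypothetical forward/backward helper
# reset so sampling and saving see full strength again
network.multiplier = 1.0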

View File

@@ -7,6 +7,8 @@ from typing import Union
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
import torch
import torch.backends.cuda
from toolkit.basic import value_map
from toolkit.data_loader import get_dataloader_from_datasets
@@ -21,6 +23,7 @@ from toolkit.progress_bar import ToolkitProgressBar
from toolkit.sampler import get_sampler
from toolkit.scheduler import get_lr_scheduler
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
from jobs.process import BaseTrainProcess
@@ -28,7 +31,6 @@ from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safete
from toolkit.train_tools import get_torch_dtype
import gc
from tqdm import tqdm
from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
@@ -135,6 +137,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network: Union[Network, None] = None
self.embedding: Union[Embedding, None] = None
# get the device state preset based on what we are training
self.train_device_state_preset = get_train_sd_device_state_preset(
device=self.device_torch,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
cached_latents=self.is_latents_cached,
train_lora=self.network_config is not None,
train_embedding=self.embed_config is not None,
)
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
gen_img_config_list = []
@@ -477,6 +489,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# if it has it
if hasattr(te, 'enable_xformers_memory_efficient_attention'):
te.enable_xformers_memory_efficient_attention()
if self.train_config.sdp:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
if self.train_config.gradient_checkpointing:
unet.enable_gradient_checkpointing()
@@ -513,7 +529,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.datasets is not None:
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd)
if self.datasets_reg is not None:
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
                                                    self.sd)
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
@@ -547,6 +564,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network.force_to(self.device_torch, dtype=dtype)
# give network to sd so it can use it
self.sd.network = self.network
self.network._update_torch_multiplier()
self.network.apply_to(
text_encoder,
@@ -621,32 +639,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
if not params:
# set trainable params
params = self.embedding.get_trainable_params()
flush()
else:
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)
params = self.get_params()
@@ -729,25 +727,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# zero any gradients
optimizer.zero_grad()
self.lr_scheduler.step(self.step_num)
self.sd.set_device_state(self.train_device_state_preset)
flush()
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
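
Note on the new sdp flag above: enabling all three scaled-dot-product-attention backends lets PyTorch pick the fastest kernel per call. A minimal sketch using the public PyTorch 2.x API (shapes are illustrative):

import torch
import torch.backends.cuda

torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)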

View File

@@ -85,6 +85,7 @@ class TrainConfig:
self.batch_size: int = kwargs.get('batch_size', 1)
self.dtype: str = kwargs.get('dtype', 'fp32')
self.xformers = kwargs.get('xformers', False)
self.sdp = kwargs.get('sdp', False)
self.train_unet = kwargs.get('train_unet', True)
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
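
Since every option is read from kwargs with a default, the new sdp flag is opt-in. Hypothetical usage:

train_config = TrainConfig(batch_size=4, dtype='bf16', sdp=True)
assert train_config.xformers is False  # unset keys keep their defaults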

View File

@@ -4,7 +4,6 @@ import os
import re
import sys
from typing import List, Optional, Dict, Type, Union
import torch
from transformers import CLIPTextModel
@@ -46,11 +45,12 @@ class LoRAModule(ToolkitModuleMixin, torch.nn.Module):
dropout=None,
rank_dropout=None,
module_dropout=None,
network: 'LoRASpecialNetwork' = None,
parent=None,
**kwargs
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__(network=network)
self.lora_name = lora_name
self.scalar = torch.tensor(1.0)
@@ -150,7 +150,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
5. Specify modules_dim and modules_alpha (for inference)
"""
# call the parent of the parent we are replacing (LoRANetwork) init
torch.nn.Module.__init__(self)
self.lora_dim = lora_dim
self.alpha = alpha
@@ -163,6 +163,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self._multiplier: float = 1.0
self.is_active: bool = False
self._is_normalizing: bool = False
self.torch_multiplier = None
# triggers the state updates
self.multiplier = multiplier
self.is_sdxl = is_sdxl
@@ -258,6 +259,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
network=self,
parent=module,
)
loras.append(lora)
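
Passing network=self gives every module a handle back to its owner; the mixin below stores it as a weak reference, so modules can read shared state such as torch_multiplier without creating a reference cycle. A self-contained sketch of the pattern with hypothetical names:

import weakref

class Owner:
    def __init__(self):
        self.torch_multiplier = None
        self.members = [Member(self) for _ in range(3)]

class Member:
    def __init__(self, owner):
        # weak reference: does not keep the owner alive, so no cycle
        self.owner_ref = weakref.ref(owner)

    def read_multiplier(self):
        return self.owner_ref().torch_multiplier  # call the ref to dereference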

View File

@@ -29,6 +29,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
lora_dim=4, alpha=1,
dropout=0., rank_dropout=0., module_dropout=0.,
use_cp=False,
network: 'LycorisSpecialNetwork' = None,
parent=None,
**kwargs,
):
@@ -36,7 +37,7 @@ class LoConSpecialModule(ToolkitModuleMixin, LoConModule):
# call super of super
torch.nn.Module.__init__(self)
# call super of
super().__init__(call_super_init=False, network=network)
self.lora_name = lora_name
self.lora_dim = lora_dim
self.cp = False
@@ -170,6 +171,8 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
if module_dropout is None:
module_dropout = 0
self.torch_multiplier = None
# triggers a tensor update
self.multiplier = multiplier
self.lora_dim = lora_dim
@@ -229,6 +232,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
**kwargs
)
@@ -240,6 +244,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
**kwargs
)
@@ -249,6 +254,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.conv_lora_dim, self.conv_alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
**kwargs
)
@@ -271,6 +277,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
parent=module,
network=self,
**kwargs
)
elif module.__class__.__name__ == 'Conv2d':
@@ -281,6 +288,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.lora_dim, self.alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
**kwargs
)
@@ -290,6 +298,7 @@ class LycorisSpecialNetwork(ToolkitNetworkMixin, LycorisNetwork):
self.conv_lora_dim, self.conv_alpha,
self.dropout, self.rank_dropout, self.module_dropout,
use_cp,
network=self,
parent=module,
**kwargs
)

View File

@@ -4,10 +4,8 @@ from collections import OrderedDict
from typing import Optional, Union, List, Type, TYPE_CHECKING, Dict, Any
import torch
from diffusers.utils import is_torch_version
from torch import nn
from torch.utils.checkpoint import checkpoint
import weakref
from toolkit.metadata import add_model_hash_to_meta
from toolkit.paths import KEYMAPS_ROOT
@@ -47,11 +45,13 @@ class ToolkitModuleMixin:
def __init__(
self: Module,
*args,
network: Network,
call_super_init: bool = True,
**kwargs
):
if call_super_init:
super().__init__(*args, **kwargs)
self.network_ref: weakref.ref = weakref.ref(network)
self.is_checkpointing = False
self.is_normalizing = False
self.normalize_scaler = 1.0
@@ -125,13 +125,13 @@ class ToolkitModuleMixin:
# this may get an additional positional arg or not
def forward(self: Module, x, *args, **kwargs):
# diffusers added scale to resnet.. not sure what it does
if not self.network_ref().is_active:
# network is not active, avoid doing anything
return self.org_forward(x, *args, **kwargs)
org_forwarded = self.org_forward(x, *args, **kwargs)
lora_output = self._call_forward(x)
multiplier = self.network_ref().torch_multiplier
lora_output_batch_size = lora_output.size(0)
multiplier_batch_size = multiplier.size(0)
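
The two sizes above feed a broadcast step: a one-element multiplier scales the whole batch, while an n-element multiplier scales per sample. A small self-contained sketch of that assumed behavior:

import torch

lora_output = torch.randn(4, 8)    # hypothetical batch of 4
multiplier = torch.tensor([0.25])  # single strength for every sample
if multiplier.size(0) != lora_output.size(0):
    # repeat the scalar so it lines up with the batch dimension
    multiplier = multiplier.repeat(lora_output.size(0))
scaled = lora_output * multiplier.view(-1, 1)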
@@ -328,35 +328,52 @@ class ToolkitNetworkMixin:
extra_dict = None
return extra_dict
def _update_torch_multiplier(self: Network):
# builds a tensor for fast usage in the forward pass of the network modules
# without having to set it in every single module every time it changes
multiplier = self._multiplier
# get first module
first_module = self.get_all_modules()[0]
device = first_module.lora_down.weight.device
dtype = first_module.lora_down.weight.dtype
with torch.no_grad():
tensor_multiplier = None
if isinstance(multiplier, int) or isinstance(multiplier, float):
tensor_multiplier = torch.tensor((multiplier,)).to(device, dtype=dtype)
elif isinstance(multiplier, list):
tensor_list = []
for m in multiplier:
if isinstance(m, int) or isinstance(m, float):
tensor_list.append(torch.tensor((m,)).to(device, dtype=dtype))
elif isinstance(m, torch.Tensor):
tensor_list.append(m.clone().detach().to(device, dtype=dtype))
tensor_multiplier = torch.cat(tensor_list)
elif isinstance(multiplier, torch.Tensor):
tensor_multiplier = multiplier.clone().detach().to(device, dtype=dtype)
self.torch_multiplier = tensor_multiplier.clone().detach()
@property
def multiplier(self) -> Union[float, List[float]]:
return self._multiplier
@multiplier.setter
def multiplier(self, value: Union[float, List[float]]):
# only update if the value has changed
# it takes time to update all the multipliers, so we only do it if the value has changed
if self._multiplier == value:
return
# if we are setting a single value but have a list, keep the list if every item is the same as value
self._multiplier = value
self._update_lora_multiplier()
def _update_lora_multiplier(self: Network):
if self.is_active:
for lora in self.get_all_modules():
lora.set_multiplier(self._multiplier)
else:
for lora in self.get_all_modules():
lora.set_multiplier(0)
self._update_torch_multiplier()
# called when the context manager is entered
# ie: with network:
def __enter__(self: Network):
self.is_active = True
self._update_lora_multiplier()
def __exit__(self: Network, exc_type, exc_value, tb):
self.is_active = False
self._update_lora_multiplier()
def force_to(self: Network, device, dtype):
self.to(device, dtype)
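
Taken together, this is where the commit's speedup comes from: the multiplier is converted to a tensor once per change and cached on the network, the setter returns early when the value is unchanged, and each module reads the cached tensor in its forward pass instead of being updated individually. A hypothetical usage sketch (net is an instance of a network using this mixin):

net.multiplier = 0.8    # builds net.torch_multiplier once
net.multiplier = 0.8    # same value: setter returns early, no per-module work
with net:               # is_active = True, modules apply the multiplier
    out = unet(latents) # hypothetical forward pass
# on exit is_active = False and modules fall back to org_forward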

View File

@@ -6,6 +6,7 @@ class ToolkitProgressBar(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.paused = False
self.last_time = self._time()
def pause(self):
if not self.paused:
@@ -15,7 +16,9 @@ class ToolkitProgressBar(tqdm):
def unpause(self):
if self.paused:
self.paused = False
cur_t = self._time()
self.start_t += cur_t - self.last_time
self.last_print_t = cur_t
def update(self, *args, **kwargs):
if not self.paused:
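
With last_time initialized at construction and last_print_t reset on unpause, the paused interval no longer skews tqdm's rate estimate. Hypothetical usage, pausing around work that should not count toward the ETA:

import time

pbar = ToolkitProgressBar(total=100)
for i in range(100):
    if i == 50:
        pbar.pause()
        time.sleep(2)   # e.g. sampling images mid-training
        pbar.unpause()  # the paused time is subtracted out
    pbar.update(1)
pbar.close()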

View File

@@ -0,0 +1,59 @@
import torch
import copy
empty_preset = {
'vae': {
'training': False,
'device': 'cpu',
},
'unet': {
'training': False,
'requires_grad': False,
'device': 'cpu',
},
'text_encoder': {
'training': False,
'requires_grad': False,
'device': 'cpu',
}
}
def get_train_sd_device_state_preset(
device: torch.DeviceObjType,
train_unet: bool = False,
train_text_encoder: bool = False,
cached_latents: bool = False,
train_lora: bool = False,
train_embedding: bool = False,
):
preset = copy.deepcopy(empty_preset)
if not cached_latents:
preset['vae']['device'] = device
if train_unet:
preset['unet']['training'] = True
preset['unet']['requires_grad'] = True
preset['unet']['device'] = device
else:
preset['unet']['device'] = device
if train_text_encoder:
preset['text_encoder']['training'] = True
preset['text_encoder']['requires_grad'] = True
preset['text_encoder']['device'] = device
else:
preset['text_encoder']['device'] = device
if train_embedding:
preset['text_encoder']['training'] = True
preset['text_encoder']['requires_grad'] = True
preset['unet']['training'] = True
if train_lora:
preset['text_encoder']['requires_grad'] = False
preset['unet']['requires_grad'] = False
return preset
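
A sketch of how this new helper is meant to plug in, assuming sd is a StableDiffusion instance exposing the set_device_state method from toolkit.stable_diffusion_model:

import torch

preset = get_train_sd_device_state_preset(
    device=torch.device('cuda'),
    train_unet=True,
    cached_latents=True,  # VAE stays on CPU
    train_lora=True,      # base weights keep requires_grad False
)
sd.set_device_state(preset)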

View File

@@ -24,7 +24,6 @@ from toolkit.sampler import get_sampler
from toolkit.saving import save_ldm_model_from_diffusers
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
@@ -48,7 +47,7 @@ DO_NOT_TRAIN_WEIGHTS = [
"unet_time_embedding.linear_2.weight",
]
DeviceStatePreset = Literal['cache_latents', 'generate']
class BlankNetwork:
@@ -111,7 +110,7 @@ class StableDiffusion:
self.unet: Union[None, 'UNet2DConditionModel']
self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
self.noise_scheduler: Union[None, 'KarrasDiffusionSchedulers'] = noise_scheduler
# sdxl stuff
self.logit_scale = None
@@ -247,7 +246,7 @@ class StableDiffusion:
# pipe.unet = prepare_unet_for_training(pipe.unet)
self.unet = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype)
self.vae.eval()
self.vae.requires_grad_(False)
self.unet.to(self.device_torch, dtype=dtype)
@@ -275,26 +274,12 @@ class StableDiffusion:
network.is_normalizing = False
self.save_device_state()
self.set_device_state_preset('generate')
# save current seed state for training
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
flush()
noise_scheduler = self.noise_scheduler
if sampler is not None:
if sampler.startswith("sample_"): # sample_dpmpp_2m
@@ -346,6 +331,7 @@ class StableDiffusion:
start_multiplier = self.network.multiplier
pipeline.to(self.device_torch)
with network:
with torch.no_grad():
if self.network is not None:
@@ -876,6 +862,7 @@ class StableDiffusion:
'unet': {
'training': self.unet.training,
'device': self.unet.device,
'requires_grad': self.unet.conv_in.weight.requires_grad,
},
}
if isinstance(self.text_encoder, list):
@@ -884,11 +871,14 @@ class StableDiffusion:
self.device_state['text_encoder'].append({
'training': encoder.training,
'device': encoder.device,
# todo there has to be a better way to do this
'requires_grad': encoder.text_model.final_layer_norm.weight.requires_grad
})
else:
self.device_state['text_encoder'] = {
'training': self.text_encoder.training,
'device': self.text_encoder.device,
'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
}
def restore_device_state(self):
@@ -910,19 +900,33 @@ class StableDiffusion:
else:
self.unet.eval()
self.unet.to(state['unet']['device'])
if state['unet']['requires_grad']:
self.unet.requires_grad_(True)
else:
self.unet.requires_grad_(False)
if isinstance(self.text_encoder, list):
for i, encoder in enumerate(self.text_encoder):
if isinstance(state['text_encoder'], list):
    if state['text_encoder'][i]['training']:
        encoder.train()
    else:
        encoder.eval()
    encoder.to(state['text_encoder'][i]['device'])
    encoder.requires_grad_(state['text_encoder'][i]['requires_grad'])
else:
    if state['text_encoder']['training']:
        encoder.train()
    else:
        encoder.eval()
    encoder.to(state['text_encoder']['device'])
    encoder.requires_grad_(state['text_encoder']['requires_grad'])
else:
if state['text_encoder']['training']:
self.text_encoder.train()
else:
self.text_encoder.eval()
self.text_encoder.to(state['text_encoder']['device'])
self.text_encoder.requires_grad_(state['text_encoder']['requires_grad'])
flush()
def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
@@ -935,18 +939,22 @@ class StableDiffusion:
training_modules = []
if device_state_preset in ['cache_latents']:
active_modules = ['vae']
if device_state_preset in ['generate']:
active_modules = ['vae', 'unet', 'text_encoder']
state = {}
# vae
state['vae'] = {
'training': 'vae' in training_modules,
'device': self.device_torch if 'vae' in active_modules else 'cpu',
'requires_grad': 'vae' in training_modules,
}
# unet
state['unet'] = {
'training': 'unet' in training_modules,
'device': self.device_torch if 'unet' in active_modules else 'cpu',
'requires_grad': 'unet' in training_modules,
}
# text encoder
@@ -956,11 +964,13 @@ class StableDiffusion:
state['text_encoder'].append({
'training': 'text_encoder' in training_modules,
'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
'requires_grad': 'text_encoder' in training_modules,
})
else:
state['text_encoder'] = {
'training': 'text_encoder' in training_modules,
'device': self.device_torch if 'text_encoder' in active_modules else 'cpu',
'requires_grad': 'text_encoder' in training_modules,
}
self.set_device_state(state)
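
Read together, save_device_state, set_device_state_preset('generate'), and restore_device_state bracket sampling. A sketch of the intended flow (generate_images is a hypothetical sampling call):

sd.save_device_state()                  # snapshot training/device/grad flags
sd.set_device_state_preset('generate')  # everything on GPU, eval, no grads
# images = sd.generate_images(...)      # hypothetical sampling call
sd.restore_device_state()               # put the training state back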