diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index fafeff12..228f9491 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -56,8 +56,9 @@ import gc
 from tqdm import tqdm
 from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
-    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig
+    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs
 from toolkit.logging import create_logger
+from diffusers import FluxTransformer2DModel


 def flush():
     torch.cuda.empty_cache()
@@ -201,6 +202,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.named_lora = True
         self.snr_gos: Union[LearnableSNRGamma, None] = None
         self.ema: ExponentialMovingAverage = None
+
+        validate_configs(self.train_config, self.model_config, self.save_config)

     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -587,9 +590,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def hook_before_train_loop(self):
         self.logger.start()

-    def ensure_params_requires_grad(self):
-        # get param groups
-        # for group in self.optimizer.param_groups:
+    def ensure_params_requires_grad(self, force=False):
+        if self.train_config.do_paramiter_swapping and not force:
+            # the optimizer will handle this if we are not forcing
+            return
         for group in self.params:
             for param in group['params']:
                 if isinstance(param, torch.nn.Parameter):  # Ensure it's a proper parameter
@@ -1278,6 +1282,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
             torch.backends.cuda.enable_math_sdp(True)
             torch.backends.cuda.enable_flash_sdp(True)
             torch.backends.cuda.enable_mem_efficient_sdp(True)
+
+        # # check if we have sage and is flux
+        # if self.sd.is_flux:
+        #     # try_to_activate_sage_attn()
+        #     try:
+        #         from sageattention import sageattn
+        #         from toolkit.models.flux_sage_attn import FluxSageAttnProcessor2_0
+        #         model: FluxTransformer2DModel = self.sd.unet
+        #         # enable sage attention on each block
+        #         for block in model.transformer_blocks:
+        #             processor = FluxSageAttnProcessor2_0()
+        #             block.attn.set_processor(processor)
+        #         for block in model.single_transformer_blocks:
+        #             processor = FluxSageAttnProcessor2_0()
+        #             block.attn.set_processor(processor)
+
+        #     except ImportError:
+        #         print("sage attention is not installed. Using SDP instead")

         if self.train_config.gradient_checkpointing:
             if self.sd.is_flux:
@@ -1539,10 +1561,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
         optimizer_type = self.train_config.optimizer.lower()

         # esure params require grad
-        self.ensure_params_requires_grad()
+        self.ensure_params_requires_grad(force=True)
         optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
                                   optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
+
+        # set it to do parameter swapping
+        if self.train_config.do_paramiter_swapping:
+            # only works for Adafactor; an unsupported optimizer should have raised an error before this point
+            self.optimizer.enable_paramiter_swapping(self.train_config.paramiter_swapping_factor)

         # check if it exists
         optimizer_state_filename = f'optimizer.pt'
@@ -1648,7 +1675,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # torch.compile(self.sd.unet, dynamic=True)

         # make sure all params require grad
-        self.ensure_params_requires_grad()
+        self.ensure_params_requires_grad(force=True)

         ###################################################################

@@ -1659,6 +1686,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         start_step_num = self.step_num
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
+            if self.train_config.do_paramiter_swapping:
+                self.optimizer.swap_paramiters()
             self.timer.start('train_loop')
             if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
@@ -1738,6 +1767,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     # flush()

                 ### HOOK ###
+
                 loss_dict = self.hook_train_loop(batch_list)
                 self.timer.stop('train_loop')
                 if not did_first_flush:
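
The three touchpoints above (force-enabling grads at setup, handing control to the optimizer via enable_paramiter_swapping, and calling swap_paramiters at the top of every step) are easier to see outside the trainer. The following is a minimal sketch only: it assumes the toolkit's Adafactor keeps the usual (params, lr=...) optimizer constructor it is adapted from and that a plain float32 stand-in model is acceptable, so treat it as an illustration rather than the project's API.

import torch
from toolkit.optimizers.adafactor import Adafactor  # assumes the toolkit is importable

# stand-in network with many small parameter tensors so a subset can be activated
model = torch.nn.Sequential(*[torch.nn.Linear(16, 16) for _ in range(10)])
params = [{'params': list(model.parameters())}]

optimizer = Adafactor(
    params, lr=1e-4, relative_step=False,
    do_paramiter_swapping=True,        # same flag the trainer forwards from TrainConfig
    paramiter_swapping_factor=0.25,    # roughly 25% of parameter elements active per step
)

for step in range(10):
    optimizer.swap_paramiters()        # re-pick the active subset, as the train loop does
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()                    # grads only flow to the currently active params
    optimizer.step()
    optimizer.zero_grad()
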
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 5eb1a37c..5d9dc92b 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -389,6 +389,10 @@ class TrainConfig:
         # will cache a blank prompt or the trigger word, and unload the text encoder to cpu
         # will make training faster and use less vram
         self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
+        # for swapping which parameters are trained during training
+        self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
+        # 0.1 means 10% of the parameters are active at a time; lower uses less vram, higher uses more
+        self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1)


 class ModelConfig:
@@ -898,4 +902,16 @@ class GenerateImageConfig:
         if self.logger is None:
             return

-        self.logger.log_image(image, count, self.prompt)
\ No newline at end of file
+        self.logger.log_image(image, count, self.prompt)
+
+
+def validate_configs(
+        train_config: TrainConfig,
+        model_config: ModelConfig,
+        save_config: SaveConfig,
+):
+    if model_config.is_flux:
+        if save_config.save_format != 'diffusers':
+            # make it diffusers
+            save_config.save_format = 'diffusers'
+
\ No newline at end of file
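
A quick sketch of how these additions are consumed. The TrainConfig keys come straight from the hunk above; the ModelConfig/SaveConfig objects are replaced with SimpleNamespace stand-ins here because validate_configs only reads is_flux and save_format, so the real constructors are not assumed.

from types import SimpleNamespace
from toolkit.config_modules import TrainConfig, validate_configs

train_config = TrainConfig(
    do_paramiter_swapping=True,       # opt in to per-step parameter swapping
    paramiter_swapping_factor=0.1,    # roughly 10% of the parameters active per step
)

# stand-ins with just the attributes validate_configs reads
model_config = SimpleNamespace(is_flux=True)
save_config = SimpleNamespace(save_format='safetensors')

validate_configs(train_config, model_config, save_config)
print(save_config.save_format)  # 'diffusers' -- coerced for Flux models

Note the coercion is silent rather than an error; BaseSDTrainProcess calls validate_configs once in its constructor, before anything is saved.
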
diff --git a/toolkit/models/flux_sage_attn.py b/toolkit/models/flux_sage_attn.py
new file mode 100644
index 00000000..930a1700
--- /dev/null
+++ b/toolkit/models/flux_sage_attn.py
@@ -0,0 +1,94 @@
+from typing import Optional
+from diffusers.models.attention_processor import Attention
+import torch
+import torch.nn.functional as F
+
+
+class FluxSageAttnProcessor2_0:
+    """Flux attention processor that routes the attention computation through SageAttention."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("FluxSageAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        from sageattention import sageattn
+
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from diffusers.models.embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = sageattn(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
\ No newline at end of file
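
The processor mirrors diffusers' FluxAttnProcessor2_0 and only swaps the scaled-dot-product call for sageattn, so enabling it is just a matter of replacing the processor on each Flux attention module. The helper below restates the commented-out activation block from BaseSDTrainProcess.py as a standalone function; the function name is ours, while everything it calls comes from the diff or the diffusers API.

from diffusers import FluxTransformer2DModel
from toolkit.models.flux_sage_attn import FluxSageAttnProcessor2_0


def set_sage_attention(transformer: FluxTransformer2DModel) -> None:
    """Swap every Flux attention processor for the SageAttention variant (illustrative helper)."""
    try:
        import sageattention  # noqa: F401 -- only checks that the kernels are installed
    except ImportError:
        print("sage attention is not installed. Using SDP instead")
        return
    # both the double-stream and single-stream Flux blocks carry an Attention module
    for block in transformer.transformer_blocks:
        block.attn.set_processor(FluxSageAttnProcessor2_0())
    for block in transformer.single_transformer_blocks:
        block.attn.set_processor(FluxSageAttnProcessor2_0())
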
diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py
index b98f4590..2f1a8997 100644
--- a/toolkit/optimizers/adafactor.py
+++ b/toolkit/optimizers/adafactor.py
@@ -3,6 +3,7 @@ from typing import List
 import torch
 from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
 from optimum.quanto import QBytesTensor
+import random


 class Adafactor(torch.optim.Optimizer):
@@ -105,6 +106,8 @@ class Adafactor(torch.optim.Optimizer):
         scale_parameter=True,
         relative_step=True,
         warmup_init=False,
+        do_paramiter_swapping=False,
+        paramiter_swapping_factor=0.1,
     ):
         if lr is not None and relative_step:
             raise ValueError(
@@ -140,6 +143,49 @@ class Adafactor(torch.optim.Optimizer):
                     param.register_post_accumulate_grad_hook(
                         stochastic_grad_accummulation
                     )
+
+        self.do_paramiter_swapping = do_paramiter_swapping
+        self.paramiter_swapping_factor = paramiter_swapping_factor
+        self._total_paramiter_size = 0
+        # count total parameters
+        for group in self.param_groups:
+            for param in group['params']:
+                self._total_paramiter_size += torch.numel(param)
+        # pretty print total parameters with comma separation
+        print(f"Total training parameters: {self._total_paramiter_size:,}")
+
+        # swapping can only be enabled after the parameter count above
+        if self.do_paramiter_swapping:
+            self.enable_paramiter_swapping(self.paramiter_swapping_factor)
+
+
+    def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1):
+        self.do_paramiter_swapping = True
+        self.paramiter_swapping_factor = paramiter_swapping_factor
+        # do an initial swap so only a subset starts out active
+        self.swap_paramiters()
+
+    def swap_paramiters(self):
+        all_params = []
+        # deactivate all parameters
+        for group in self.param_groups:
+            for param in group['params']:
+                param.requires_grad_(False)
+                # remove any grad
+                param.grad = None
+                all_params.append(param)
+        # shuffle all parameters
+        random.shuffle(all_params)
+
+        # keep activating parameters until we would go over the target count
+        target_paramiters = int(self._total_paramiter_size * self.paramiter_swapping_factor)
+        total_paramiters = 0
+        for param in all_params:
+            total_paramiters += torch.numel(param)
+            if total_paramiters >= target_paramiters:
+                break
+            else:
+                param.requires_grad_(True)

     @staticmethod
     def _get_lr(param_group, param_state):
@@ -209,7 +255,7 @@ class Adafactor(torch.optim.Optimizer):

         for group in self.param_groups:
             for p in group["params"]:
-                if p.grad is None:
+                if p.grad is None or not p.requires_grad:
                     continue
                 grad = p.grad
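
For clarity, here is the selection rule swap_paramiters applies each step, written out on plain tensors. This is an illustration only (the helper name is ours); the real logic lives in the method above. It deactivates everything, drops any held gradients, then re-activates a random subset until roughly paramiter_swapping_factor of the total element count is trainable, leaving the tensor that would cross the target inactive.

import random
from typing import List

import torch


def pick_active_subset(params: List[torch.nn.Parameter], factor: float = 0.1) -> None:
    # deactivate everything and drop held gradients to free memory
    for p in params:
        p.requires_grad_(False)
        p.grad = None
    shuffled = list(params)
    random.shuffle(shuffled)
    target = int(sum(p.numel() for p in params) * factor)
    seen = 0
    for p in shuffled:
        seen += p.numel()
        if seen >= target:
            break  # the tensor that crosses the target stays inactive
        p.requires_grad_(True)


# usage on a throwaway stack of layers
layers = [torch.nn.Linear(16, 16) for _ in range(10)]
all_params = [p for layer in layers for p in layer.parameters()]
pick_active_subset(all_params, factor=0.25)
print(sum(p.numel() for p in all_params if p.requires_grad), "of", sum(p.numel() for p in all_params))

Combined with the `or not p.requires_grad` guard added to step(), only the active subset accumulates gradients and receives Adafactor updates on any given step, which is where the VRAM savings mentioned in the TrainConfig comment come from.
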