From 216ab164ce15761896bd49632d54bfa83a061817 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 4 Feb 2025 13:36:34 -0700
Subject: [PATCH] Experimental features and bug fixes

---
 jobs/process/BaseSDTrainProcess.py             |  2 +-
 toolkit/config_modules.py                      |  1 +
 toolkit/models/diffusion_feature_extraction.py |  3 ++-
 toolkit/models/flux.py                         | 14 ++++++++------
 toolkit/optimizers/adafactor.py                | 16 +++++++++-------
 toolkit/stable_diffusion_model.py              |  6 +++++-
 6 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 1d8278e7..2183709f 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1839,7 +1839,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
             if self.train_config.do_paramiter_swapping:
-                self.optimizer.swap_paramiters()
+                self.optimizer.optimizer.swap_paramiters()
             self.timer.start('train_loop')
             if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 15dca441..7aa30b73 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -483,6 +483,7 @@ class ModelConfig:
         self.split_model_over_gpus = kwargs.get("split_model_over_gpus", False)
         if self.split_model_over_gpus and not self.is_flux:
             raise ValueError("split_model_over_gpus is only supported with flux models currently")
+        self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3)
 
 
 class EMAConfig:
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 2e8195a3..8c6fd966 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -250,13 +250,14 @@ class DiffusionFeatureExtractor3(nn.Module):
         bs = noise_pred.shape[0]
         noise_pred_chunks = torch.chunk(noise_pred, bs)
         timestep_chunks = torch.chunk(timesteps, bs)
+        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
         stepped_chunks = []
         for idx in range(bs):
             model_output = noise_pred_chunks[idx]
             timestep = timestep_chunks[idx]
             scheduler._step_index = None
             scheduler._init_step_index(timestep)
-            sample = noisy_latents.to(torch.float32)
+            sample = noisy_latent_chunks[idx].to(torch.float32)
 
             sigma = scheduler.sigmas[scheduler.step_index]
             sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
diff --git a/toolkit/models/flux.py b/toolkit/models/flux.py
index 829283e4..c2fb5ac9 100644
--- a/toolkit/models/flux.py
+++ b/toolkit/models/flux.py
@@ -117,16 +117,18 @@ def split_gpu_single_block_forward(
     return hidden_state_out
 
 
-def add_model_gpu_splitter_to_flux(transformer: FluxTransformer2DModel):
+def add_model_gpu_splitter_to_flux(
+    transformer: FluxTransformer2DModel,
+    # ~ 5 billion for all other params
+    other_module_params: Optional[int] = 5e9,
+    # since they are not trainable, multiply by smaller number
+    other_module_param_count_scale: Optional[float] = 0.3
+):
     gpu_id_list = [i for i in range(torch.cuda.device_count())]
 
     # if len(gpu_id_list) > 2:
     #     raise ValueError("Cannot split to more than 2 GPUs currently.")
-
-    # ~ 5 billion for all other params
-    other_module_params = 5e9
-    # since they are not trainable, multiply by smaller number
-    other_module_params *= 0.5
+    other_module_params *= other_module_param_count_scale
 
     # since we are not tuning the
     total_params = sum(p.numel() for p in transformer.parameters()) + other_module_params
diff --git a/toolkit/optimizers/adafactor.py b/toolkit/optimizers/adafactor.py
index 00cf06ee..4d97b2cd 100644
--- a/toolkit/optimizers/adafactor.py
+++ b/toolkit/optimizers/adafactor.py
@@ -108,6 +108,7 @@ class Adafactor(torch.optim.Optimizer):
         warmup_init=False,
         do_paramiter_swapping=False,
         paramiter_swapping_factor=0.1,
+        stochastic_accumulation=True,
     ):
         if lr is not None and relative_step:
             raise ValueError(
@@ -136,13 +137,14 @@ class Adafactor(torch.optim.Optimizer):
         self.is_stochastic_rounding_accumulation = False
 
         # setup stochastic grad accum hooks
-        for group in self.param_groups:
-            for param in group['params']:
-                if param.requires_grad and param.dtype != torch.float32:
-                    self.is_stochastic_rounding_accumulation = True
-                    param.register_post_accumulate_grad_hook(
-                        stochastic_grad_accummulation
-                    )
+        if stochastic_accumulation:
+            for group in self.param_groups:
+                for param in group['params']:
+                    if param.requires_grad and param.dtype != torch.float32:
+                        self.is_stochastic_rounding_accumulation = True
+                        param.register_post_accumulate_grad_hook(
+                            stochastic_grad_accummulation
+                        )
 
         self.do_paramiter_swapping = do_paramiter_swapping
         self.paramiter_swapping_factor = paramiter_swapping_factor
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 18ac335d..7bdc586b 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -66,6 +66,7 @@ from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 from toolkit.accelerator import get_accelerator, unwrap_model
 from typing import TYPE_CHECKING
 from toolkit.print import print_acc
+from diffusers import FluxFillPipeline
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRASpecialNetwork
@@ -560,7 +561,10 @@ class StableDiffusion:
             )
             # hack in model gpu splitter
            if self.model_config.split_model_over_gpus:
-                add_model_gpu_splitter_to_flux(transformer)
+                add_model_gpu_splitter_to_flux(
+                    transformer,
+                    other_module_param_count_scale=self.model_config.split_model_other_module_param_count_scale
+                )
 
             if not self.low_vram:
                 # for low v ram, we leave it on the cpu. Quantizes slower, but allows training on primary gpu
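
Reviewer note (not part of the patch): a minimal sketch of how the new
split_model_other_module_param_count_scale value feeds into the GPU split,
assuming the splitter divides the weighted total evenly across visible
devices. The helper name, the even-split assumption, and the example numbers
are illustrative only; the diff above only confirms that other_module_params
is scaled by other_module_param_count_scale and added to total_params.

    # Illustrative sketch only -- assumed helper, not part of toolkit/models/flux.py
    def per_gpu_param_budget(
        transformer_param_count: int,                 # e.g. sum(p.numel() for p in transformer.parameters())
        num_gpus: int,
        other_module_params: float = 5e9,             # rough count of non-transformer params (patch default)
        other_module_param_count_scale: float = 0.3,  # new config default added in this patch
    ) -> float:
        # Non-trainable modules are down-weighted before being added to the
        # total, as in the patched flux.py; the per-GPU division below is an
        # assumption made for illustration.
        total = transformer_param_count + other_module_params * other_module_param_count_scale
        return total / num_gpus

    # e.g. a ~12B-parameter flux transformer over 2 GPUs:
    # per_gpu_param_budget(12e9, 2) == (12e9 + 1.5e9) / 2 == 6.75e9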