From 1cc663a66466040b4822c4d896484411df5d1b53 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 17 Jun 2025 07:37:41 -0600 Subject: [PATCH] Performance optimizations for preprocessing the batch --- README.md | 10 ++ extensions_built_in/sd_trainer/SDTrainer.py | 138 ++++++++++---------- jobs/process/BaseSDTrainProcess.py | 62 ++++----- toolkit/stable_diffusion_model.py | 9 ++ 4 files changed, 120 insertions(+), 99 deletions(-) diff --git a/README.md b/README.md index fa660c6a..5b4ae070 100644 --- a/README.md +++ b/README.md @@ -417,6 +417,16 @@ Everything else should work the same including layer targeting. ## Updates +### June 17, 2025 +- Performance optimizations for batch preparation + +### June 16, 2025 +- Hide control images in the UI when viewing datasets +- WIP on mean flow loss + +### June 12, 2025 +- Fixed issue that resulted in blank captions in the dataloader + ### June 10, 2024 - Decided to keep track up updates in the readme - Added support for SDXL in the UI diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index a0de8cd1..479b4141 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1009,76 +1009,76 @@ class SDTrainer(BaseSDTrainProcess): return loss def train_single_accumulation(self, batch: DataLoaderBatchDTO): - self.timer.start('preprocess_batch') - if isinstance(self.adapter, CustomAdapter): - batch = self.adapter.edit_batch_raw(batch) - batch = self.preprocess_batch(batch) - if isinstance(self.adapter, CustomAdapter): - batch = self.adapter.edit_batch_processed(batch) - dtype = get_torch_dtype(self.train_config.dtype) - # sanity check - if self.sd.vae.dtype != self.sd.vae_torch_dtype: - self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) - if isinstance(self.sd.text_encoder, list): - for encoder in self.sd.text_encoder: - if encoder.dtype != self.sd.te_torch_dtype: - encoder.to(self.sd.te_torch_dtype) - else: - if 
self.sd.text_encoder.dtype != self.sd.te_torch_dtype: - self.sd.text_encoder.to(self.sd.te_torch_dtype) - - noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) - if self.train_config.do_cfg or self.train_config.do_random_cfg: - # pick random negative prompts - if self.negative_prompt_pool is not None: - negative_prompts = [] - for i in range(noisy_latents.shape[0]): - num_neg = random.randint(1, self.train_config.max_negative_prompts) - this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)] - this_neg_prompt = ', '.join(this_neg_prompts) - negative_prompts.append(this_neg_prompt) - self.batch_negative_prompt = negative_prompts - else: - self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] - - if self.adapter and isinstance(self.adapter, CustomAdapter): - # condition the prompt - # todo handle more than one adapter image - conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) - - network_weight_list = batch.get_network_weight_list() - if self.train_config.single_item_batching: - network_weight_list = network_weight_list + network_weight_list - - has_adapter_img = batch.control_tensor is not None - has_clip_image = batch.clip_image_tensor is not None - has_clip_image_embeds = batch.clip_image_embeds is not None - # force it to be true if doing regs as we handle those differently - if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]): - has_clip_image = True - if self._clip_image_embeds_unconditional is not None: - has_clip_image_embeds = True # we are caching embeds, handle that differently - has_clip_image = False - - if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: - raise ValueError( - "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. 
Please update your dataset config ") - - match_adapter_assist = False - - # check if we are matching the adapter assistant - if self.assistant_adapter: - if self.train_config.match_adapter_chance == 1.0: - match_adapter_assist = True - elif self.train_config.match_adapter_chance > 0.0: - match_adapter_assist = torch.rand( - (1,), device=self.device_torch, dtype=dtype - ) < self.train_config.match_adapter_chance - - self.timer.stop('preprocess_batch') - - is_reg = False with torch.no_grad(): + self.timer.start('preprocess_batch') + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_raw(batch) + batch = self.preprocess_batch(batch) + if isinstance(self.adapter, CustomAdapter): + batch = self.adapter.edit_batch_processed(batch) + dtype = get_torch_dtype(self.train_config.dtype) + # sanity check + if self.sd.vae.dtype != self.sd.vae_torch_dtype: + self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + if encoder.dtype != self.sd.te_torch_dtype: + encoder.to(self.sd.te_torch_dtype) + else: + if self.sd.text_encoder.dtype != self.sd.te_torch_dtype: + self.sd.text_encoder.to(self.sd.te_torch_dtype) + + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + if self.train_config.do_cfg or self.train_config.do_random_cfg: + # pick random negative prompts + if self.negative_prompt_pool is not None: + negative_prompts = [] + for i in range(noisy_latents.shape[0]): + num_neg = random.randint(1, self.train_config.max_negative_prompts) + this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)] + this_neg_prompt = ', '.join(this_neg_prompts) + negative_prompts.append(this_neg_prompt) + self.batch_negative_prompt = negative_prompts + else: + self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] + + if self.adapter and isinstance(self.adapter, CustomAdapter): + # condition 
the prompt + # todo handle more than one adapter image + conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) + + network_weight_list = batch.get_network_weight_list() + if self.train_config.single_item_batching: + network_weight_list = network_weight_list + network_weight_list + + has_adapter_img = batch.control_tensor is not None + has_clip_image = batch.clip_image_tensor is not None + has_clip_image_embeds = batch.clip_image_embeds is not None + # force it to be true if doing regs as we handle those differently + if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]): + has_clip_image = True + if self._clip_image_embeds_unconditional is not None: + has_clip_image_embeds = True # we are caching embeds, handle that differently + has_clip_image = False + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: + raise ValueError( + "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. 
Please update your dataset config ") + + match_adapter_assist = False + + # check if we are matching the adapter assistant + if self.assistant_adapter: + if self.train_config.match_adapter_chance == 1.0: + match_adapter_assist = True + elif self.train_config.match_adapter_chance > 0.0: + match_adapter_assist = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) < self.train_config.match_adapter_chance + + self.timer.stop('preprocess_batch') + + is_reg = False loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) for idx, file_item in enumerate(batch.file_items): if file_item.is_reg: diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 1e0393c0..313285fc 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -921,7 +921,10 @@ class BaseSDTrainProcess(BaseTrainProcess): noise = self.get_consistent_noise(latents, batch, dtype=dtype) else: if hasattr(self.sd, 'get_latent_noise_from_latents'): - noise = self.sd.get_latent_noise_from_latents(latents).to(self.device_torch, dtype=dtype) + noise = self.sd.get_latent_noise_from_latents( + latents, + noise_offset=self.train_config.noise_offset + ).to(self.device_torch, dtype=dtype) else: # get noise noise = self.sd.get_latent_noise( @@ -931,17 +934,6 @@ class BaseSDTrainProcess(BaseTrainProcess): batch_size=batch_size, noise_offset=self.train_config.noise_offset, ).to(self.device_torch, dtype=dtype) - - # if self.train_config.random_noise_shift > 0.0: - # # get random noise -1 to 1 - # noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device, - # dtype=noise.dtype) * 2 - 1 - - # # multiply by shift amount - # noise_shift *= self.train_config.random_noise_shift - - # # add to noise - # noise += noise_shift if self.train_config.blended_blur_noise: noise = get_blended_blur_noise( @@ -1085,19 +1077,20 @@ class BaseSDTrainProcess(BaseTrainProcess): # we determine noise 
from the differential of the latents unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor) - batch_size = len(batch.file_items) - min_noise_steps = self.train_config.min_denoising_steps - max_noise_steps = self.train_config.max_denoising_steps - if self.model_config.refiner_name_or_path is not None: - # if we are not training the unet, then we are only doing refiner and do not need to double up - if self.train_config.train_unet: - max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) - do_double = True - else: - min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) - do_double = False + with self.timer('prepare_scheduler'): + + batch_size = len(batch.file_items) + min_noise_steps = self.train_config.min_denoising_steps + max_noise_steps = self.train_config.max_denoising_steps + if self.model_config.refiner_name_or_path is not None: + # if we are not training the unet, then we are only doing refiner and do not need to double up + if self.train_config.train_unet: + max_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) + do_double = True + else: + min_noise_steps = round(self.train_config.max_denoising_steps * self.model_config.refiner_start_at) + do_double = False - with self.timer('prepare_noise'): num_train_timesteps = self.train_config.num_train_timesteps if self.train_config.noise_scheduler in ['custom_lcm']: @@ -1144,6 +1137,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.sd.noise_scheduler.set_timesteps( num_train_timesteps, device=self.device_torch ) + with self.timer('prepare_timesteps_indices'): content_or_style = self.train_config.content_or_style if is_reg: @@ -1193,20 +1187,26 @@ class BaseSDTrainProcess(BaseTrainProcess): timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps else: # todo, some schedulers use indices, otheres use timesteps. 
Not sure what to do here + min_idx = min_noise_steps + 1 + max_idx = max_noise_steps - 1 + if self.train_config.noise_scheduler == 'flowmatch': + # flowmatch uses indices, so we need to use indices + min_idx = 0 + max_idx = max_noise_steps - 1 timestep_indices = torch.randint( - min_noise_steps + 1, - max_noise_steps - 1, + min_idx, + max_idx, (batch_size,), device=self.device_torch ) timestep_indices = timestep_indices.long() else: raise ValueError(f"Unknown content_or_style {content_or_style}") - + with self.timer('convert_timestep_indices_to_timesteps'): # convert the timestep_indices to a timestep - timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices] - timesteps = torch.stack(timesteps, dim=0) - + timesteps = self.sd.noise_scheduler.timesteps[timestep_indices.long()] + + with self.timer('prepare_noise'): # get noise noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps) @@ -1240,6 +1240,8 @@ class BaseSDTrainProcess(BaseTrainProcess): device=noise.device, dtype=noise.dtype ) * self.train_config.random_noise_multiplier + + with self.timer('make_noisy_latents'): noise = noise * noise_multiplier diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 068e747e..4bffc816 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1763,6 +1763,15 @@ class StableDiffusion: ) noise = apply_noise_offset(noise, noise_offset) return noise + + def get_latent_noise_from_latents( + self, + latents: torch.Tensor, + noise_offset=0.0 + ): + noise = torch.randn_like(latents) + noise = apply_noise_offset(noise, noise_offset) + return noise def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False): VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)