From b5ec8e4eb1d09e663a4d3fce440e7ce468d72835 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 10 Sep 2023 18:26:44 -0600
Subject: [PATCH] Improve reference slider memory and speed

---
 .../ImageReferenceSliderTrainerProcess.py     | 42 +++++++------------
 jobs/process/TrainSliderProcess.py            |  8 ++--
 toolkit/buckets.py                            |  2 +-
 3 files changed, 19 insertions(+), 33 deletions(-)

diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
index 99c41e0d..ea5413fe 100644
--- a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
+++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
@@ -154,38 +154,27 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             timesteps = torch.cat([timesteps, timesteps], dim=0)
             network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]
 
-            flush()
-
-            loss_float = None
-            loss_mirror_float = None
-
             self.optimizer.zero_grad()
             noisy_latents.requires_grad = False
 
             # if training text encoder enable grads, else do context of no grad
             with torch.set_grad_enabled(self.train_config.train_text_encoder):
-                # text encoding
-                embedding_list = []
-                # embed the prompts
-                for prompt in prompts:
-                    embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
-                    embedding_list.append(embedding)
-                conditional_embeds = concat_prompt_embeds(embedding_list)
+                conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype)
                 conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
 
-            if self.model_config.is_xl:
-                # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
-                network_multiplier_list = network_multiplier
-                noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
-                noise_list = torch.chunk(noise, 2, dim=0)
-                timesteps_list = torch.chunk(timesteps, 2, dim=0)
-                conditional_embeds_list = split_prompt_embeds(conditional_embeds)
-            else:
-                network_multiplier_list = [network_multiplier]
-                noisy_latent_list = [noisy_latents]
-                noise_list = [noise]
-                timesteps_list = [timesteps]
-                conditional_embeds_list = [conditional_embeds]
+            # if self.model_config.is_xl:
+            #     # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
+            #     network_multiplier_list = network_multiplier
+            #     noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
+            #     noise_list = torch.chunk(noise, 2, dim=0)
+            #     timesteps_list = torch.chunk(timesteps, 2, dim=0)
+            #     conditional_embeds_list = split_prompt_embeds(conditional_embeds)
+            # else:
+            network_multiplier_list = [network_multiplier]
+            noisy_latent_list = [noisy_latents]
+            noise_list = [noise]
+            timesteps_list = [timesteps]
+            conditional_embeds_list = [conditional_embeds]
 
             losses = []
             # allow to chunk it out to save vram
@@ -213,20 +202,17 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
 
-                    # todo add snr gamma here
                     if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
                         # add min_snr_gamma
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
 
                     loss = loss.mean() * loss_jitter_multiplier
 
-                    loss_slide_float = loss.item()
                     loss_float = loss.item()
                     losses.append(loss_float)
 
                     # back propagate loss to free ram
                     loss.backward()
-                    flush()
 
             # apply gradients
             optimizer.step()
diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index 6e9aae48..dd91b155 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -312,7 +312,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
 
         denoised_latents = denoised_latents.detach()
 
-        flush()  # 4.2GB to 3GB on 512x512
+        # flush()  # 4.2GB to 3GB on 512x512
 
         # 4.20 GB RAM for 512x512
         anchor_loss_float = None
@@ -369,7 +369,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
             del anchor_target_noise
             # move anchor back to cpu
             anchor.to("cpu")
-            flush()
+            # flush()
 
         prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
         assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
@@ -440,7 +440,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
             del target_latents
             del offset_neutral
             del loss
-            flush()
+            # flush()
 
         optimizer.step()
         lr_scheduler.step()
@@ -457,7 +457,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
         )
         # move back to cpu
         prompt_pair.to("cpu")
-        flush()
+        # flush()
 
         # reset network
         self.network.multiplier = 1.0
diff --git a/toolkit/buckets.py b/toolkit/buckets.py
index 524208a3..3711ea7e 100644
--- a/toolkit/buckets.py
+++ b/toolkit/buckets.py
@@ -35,7 +35,7 @@ resolutions_1024: List[BucketResolution] = [
     {"width": 960, "height": 1024},
     {"width": 960, "height": 1088},
     {"width": 896, "height": 1088},
-    {"width": 896, "height": 1152},
+    {"width": 896, "height": 1152},  # 7:9
     {"width": 832, "height": 1152},
     {"width": 832, "height": 1216},
     {"width": 768, "height": 1280},
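
Reviewer note: the speedup comes from two changes, not one. First, the per-prompt
encode loop is replaced by a single batched self.sd.encode_prompt(prompts) call, so
the text encoder runs one forward pass instead of N sequential ones. Second, the
per-step flush() calls are removed or commented out. The patch never shows flush()
itself, so the helper below is an assumption: a minimal sketch of the common
gc.collect() + torch.cuda.empty_cache() pattern, with a hypothetical timed_loop
harness (not part of the patch) illustrating why calling it every training step is
costly while calling it rarely only trims peak memory.

import gc
import time

import torch


def flush():
    # Assumed implementation: this patch never shows the toolkit's flush(),
    # but gc.collect() + torch.cuda.empty_cache() is the usual shape of it.
    torch.cuda.empty_cache()
    gc.collect()


def timed_loop(steps: int, use_flush: bool) -> float:
    # Hypothetical micro-benchmark, for illustration only.
    x = torch.randn(4, 4, 64, 64, device="cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        y = x * 2.0  # stand-in for a training step's temporary allocations
        del y  # drop the reference so empty_cache() can actually release it
        if use_flush:
            # empty_cache() hands cached blocks back to the driver, so the
            # next iteration re-allocates through cudaMalloc, and gc.collect()
            # walks the whole Python heap. Paid every step, that adds up.
            flush()
    torch.cuda.synchronize()
    return time.perf_counter() - start


if __name__ == "__main__":
    if torch.cuda.is_available():
        print(f"with per-step flush:    {timed_loop(100, True):.3f}s")
        print(f"without per-step flush: {timed_loop(100, False):.3f}s")

Commenting the flush() calls out rather than deleting them keeps them easy to
re-enable for low-VRAM runs, which matches the leftover todo in the diff about a
general low-RAM setting.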