Improve reference slider memory and speed

This commit is contained in:
Jaret Burkett
2023-09-10 18:26:44 -06:00
parent 708b07adb7
commit b5ec8e4eb1
3 changed files with 19 additions and 33 deletions

View File

@@ -154,38 +154,27 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
timesteps = torch.cat([timesteps, timesteps], dim=0)
network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]
flush()
loss_float = None
loss_mirror_float = None
self.optimizer.zero_grad()
noisy_latents.requires_grad = False
# if training text encoder enable grads, else do context of no grad
with torch.set_grad_enabled(self.train_config.train_text_encoder):
# text encoding
embedding_list = []
# embed the prompts
for prompt in prompts:
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
conditional_embeds = concat_prompt_embeds(embedding_list)
conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype)
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
if self.model_config.is_xl:
# todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
network_multiplier_list = network_multiplier
noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
noise_list = torch.chunk(noise, 2, dim=0)
timesteps_list = torch.chunk(timesteps, 2, dim=0)
conditional_embeds_list = split_prompt_embeds(conditional_embeds)
else:
network_multiplier_list = [network_multiplier]
noisy_latent_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditional_embeds_list = [conditional_embeds]
# if self.model_config.is_xl:
# # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
# network_multiplier_list = network_multiplier
# noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
# noise_list = torch.chunk(noise, 2, dim=0)
# timesteps_list = torch.chunk(timesteps, 2, dim=0)
# conditional_embeds_list = split_prompt_embeds(conditional_embeds)
# else:
network_multiplier_list = [network_multiplier]
noisy_latent_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditional_embeds_list = [conditional_embeds]
losses = []
# allow to chunk it out to save vram
@@ -213,20 +202,17 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
# todo add snr gamma here
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean() * loss_jitter_multiplier
loss_slide_float = loss.item()
loss_float = loss.item()
losses.append(loss_float)
# back propagate loss to free ram
loss.backward()
flush()
# apply gradients
optimizer.step()