Improve reference slider memory and speed

This commit is contained in:
Jaret Burkett
2023-09-10 18:26:44 -06:00
parent 708b07adb7
commit b5ec8e4eb1
3 changed files with 19 additions and 33 deletions

View File

@@ -154,38 +154,27 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
timesteps = torch.cat([timesteps, timesteps], dim=0)
network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]
flush()
loss_float = None
loss_mirror_float = None
self.optimizer.zero_grad()
noisy_latents.requires_grad = False
# if training text encoder enable grads, else do context of no grad
with torch.set_grad_enabled(self.train_config.train_text_encoder):
# text encoding
embedding_list = []
# embed the prompts
for prompt in prompts:
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
embedding_list.append(embedding)
conditional_embeds = concat_prompt_embeds(embedding_list)
conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype)
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
if self.model_config.is_xl:
# todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
network_multiplier_list = network_multiplier
noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
noise_list = torch.chunk(noise, 2, dim=0)
timesteps_list = torch.chunk(timesteps, 2, dim=0)
conditional_embeds_list = split_prompt_embeds(conditional_embeds)
else:
network_multiplier_list = [network_multiplier]
noisy_latent_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditional_embeds_list = [conditional_embeds]
# if self.model_config.is_xl:
# # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
# network_multiplier_list = network_multiplier
# noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
# noise_list = torch.chunk(noise, 2, dim=0)
# timesteps_list = torch.chunk(timesteps, 2, dim=0)
# conditional_embeds_list = split_prompt_embeds(conditional_embeds)
# else:
network_multiplier_list = [network_multiplier]
noisy_latent_list = [noisy_latents]
noise_list = [noise]
timesteps_list = [timesteps]
conditional_embeds_list = [conditional_embeds]
losses = []
# allow to chunk it out to save vram
@@ -213,20 +202,17 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
# todo add snr gamma here
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean() * loss_jitter_multiplier
loss_slide_float = loss.item()
loss_float = loss.item()
losses.append(loss_float)
# back propagate loss to free ram
loss.backward()
flush()
# apply gradients
optimizer.step()

View File

@@ -312,7 +312,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latents = denoised_latents.detach()
flush() # 4.2GB to 3GB on 512x512
# flush() # 4.2GB to 3GB on 512x512
# 4.20 GB RAM for 512x512
anchor_loss_float = None
@@ -369,7 +369,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
del anchor_target_noise
# move anchor back to cpu
anchor.to("cpu")
flush()
# flush()
prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
@@ -440,7 +440,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
del target_latents
del offset_neutral
del loss
flush()
# flush()
optimizer.step()
lr_scheduler.step()
@@ -457,7 +457,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
# move back to cpu
prompt_pair.to("cpu")
flush()
# flush()
# reset network
self.network.multiplier = 1.0

View File

@@ -35,7 +35,7 @@ resolutions_1024: List[BucketResolution] = [
{"width": 960, "height": 1024},
{"width": 960, "height": 1088},
{"width": 896, "height": 1088},
{"width": 896, "height": 1152},
{"width": 896, "height": 1152}, # 2:3
{"width": 832, "height": 1152},
{"width": 832, "height": 1216},
{"width": 768, "height": 1280},