Improve reference slider memory and speed

This commit is contained in:
Jaret Burkett
2023-09-10 18:26:44 -06:00
parent 708b07adb7
commit b5ec8e4eb1
3 changed files with 19 additions and 33 deletions

View File

@@ -312,7 +312,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
denoised_latents = denoised_latents.detach()
-flush() # 4.2GB to 3GB on 512x512
+# flush() # 4.2GB to 3GB on 512x512
# 4.20 GB RAM for 512x512
anchor_loss_float = None
@@ -369,7 +369,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
del anchor_target_noise
# move anchor back to cpu
anchor.to("cpu")
-flush()
+# flush()
prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
@@ -440,7 +440,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
del target_latents
del offset_neutral
del loss
-flush()
+# flush()
optimizer.step()
lr_scheduler.step()
@@ -457,7 +457,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
)
# move back to cpu
prompt_pair.to("cpu")
-flush()
+# flush()
# reset network
self.network.multiplier = 1.0