Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Improve reference slider memory and speed
This commit is contained in:
@@ -154,38 +154,27 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
timesteps = torch.cat([timesteps, timesteps], dim=0)
|
||||
network_multiplier = [network_pos_weight * 1.0, network_neg_weight * -1.0]
|
||||
|
||||
flush()
|
||||
|
||||
loss_float = None
|
||||
loss_mirror_float = None
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
noisy_latents.requires_grad = False
|
||||
|
||||
# if training text encoder enable grads, else do context of no grad
|
||||
with torch.set_grad_enabled(self.train_config.train_text_encoder):
|
||||
# text encoding
|
||||
embedding_list = []
|
||||
# embed the prompts
|
||||
for prompt in prompts:
|
||||
embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
|
||||
embedding_list.append(embedding)
|
||||
conditional_embeds = concat_prompt_embeds(embedding_list)
|
||||
conditional_embeds = self.sd.encode_prompt(prompts).to(self.device_torch, dtype=dtype)
|
||||
conditional_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
|
||||
|
||||
if self.model_config.is_xl:
|
||||
# todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
|
||||
network_multiplier_list = network_multiplier
|
||||
noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
|
||||
noise_list = torch.chunk(noise, 2, dim=0)
|
||||
timesteps_list = torch.chunk(timesteps, 2, dim=0)
|
||||
conditional_embeds_list = split_prompt_embeds(conditional_embeds)
|
||||
else:
|
||||
network_multiplier_list = [network_multiplier]
|
||||
noisy_latent_list = [noisy_latents]
|
||||
noise_list = [noise]
|
||||
timesteps_list = [timesteps]
|
||||
conditional_embeds_list = [conditional_embeds]
|
||||
# if self.model_config.is_xl:
|
||||
# # todo also allow for setting this for low ram in general, but sdxl spikes a ton on back prop
|
||||
# network_multiplier_list = network_multiplier
|
||||
# noisy_latent_list = torch.chunk(noisy_latents, 2, dim=0)
|
||||
# noise_list = torch.chunk(noise, 2, dim=0)
|
||||
# timesteps_list = torch.chunk(timesteps, 2, dim=0)
|
||||
# conditional_embeds_list = split_prompt_embeds(conditional_embeds)
|
||||
# else:
|
||||
network_multiplier_list = [network_multiplier]
|
||||
noisy_latent_list = [noisy_latents]
|
||||
noise_list = [noise]
|
||||
timesteps_list = [timesteps]
|
||||
conditional_embeds_list = [conditional_embeds]
|
||||
|
||||
losses = []
|
||||
# allow to chunk it out to save vram
|
||||
@@ -213,20 +202,17 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
# todo add snr gamma here
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean() * loss_jitter_multiplier
|
||||
loss_slide_float = loss.item()
|
||||
|
||||
loss_float = loss.item()
|
||||
losses.append(loss_float)
|
||||
|
||||
# back propagate loss to free ram
|
||||
loss.backward()
|
||||
flush()
|
||||
|
||||
# apply gradients
|
||||
optimizer.step()
|
||||
|
||||
@@ -312,7 +312,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
|
||||
denoised_latents = denoised_latents.detach()
|
||||
|
||||
flush() # 4.2GB to 3GB on 512x512
|
||||
# flush() # 4.2GB to 3GB on 512x512
|
||||
|
||||
# 4.20 GB RAM for 512x512
|
||||
anchor_loss_float = None
|
||||
@@ -369,7 +369,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
del anchor_target_noise
|
||||
# move anchor back to cpu
|
||||
anchor.to("cpu")
|
||||
flush()
|
||||
# flush()
|
||||
|
||||
prompt_pair_chunks = split_prompt_pairs(prompt_pair, self.prompt_chunk_size)
|
||||
assert len(prompt_pair_chunks) == len(denoised_latent_chunks)
|
||||
@@ -440,7 +440,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
del target_latents
|
||||
del offset_neutral
|
||||
del loss
|
||||
flush()
|
||||
# flush()
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
@@ -457,7 +457,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
||||
)
|
||||
# move back to cpu
|
||||
prompt_pair.to("cpu")
|
||||
flush()
|
||||
# flush()
|
||||
|
||||
# reset network
|
||||
self.network.multiplier = 1.0
|
||||
|
||||
@@ -35,7 +35,7 @@ resolutions_1024: List[BucketResolution] = [
|
||||
{"width": 960, "height": 1024},
|
||||
{"width": 960, "height": 1088},
|
||||
{"width": 896, "height": 1088},
|
||||
{"width": 896, "height": 1152},
|
||||
{"width": 896, "height": 1152}, # 2:3
|
||||
{"width": 832, "height": 1152},
|
||||
{"width": 832, "height": 1216},
|
||||
{"width": 768, "height": 1280},
|
||||
|
||||
Reference in New Issue
Block a user