Tons of bug fixes and improvements to special training. Fixed slider training.

This commit is contained in:
Jaret Burkett
2023-12-09 16:38:10 -07:00
parent eaec2f5a52
commit eaa0fb6253
9 changed files with 639 additions and 74 deletions

View File

@@ -371,7 +371,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
# ger a random number of steps
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
1, self.train_config.max_denoising_steps - 1, (1,)
).item()
# get noise
@@ -389,7 +389,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
assert not self.network.is_active
self.sd.unet.eval()
# pass the multiplier list to the network
self.network.multiplier = prompt_pair.multiplier_list
# double up since we are doing cfg
self.network.multiplier = prompt_pair.multiplier_list + prompt_pair.multiplier_list
denoised_latents = self.sd.diffuse_some_steps(
latents, # pass simple noise latents
train_tools.concat_prompt_embeddings(
@@ -507,7 +508,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
for anchor_chunk, denoised_latent_chunk, anchor_target_noise_chunk in zip(
anchor_chunks, denoised_latent_chunks, anchor_target_noise_chunks
):
self.network.multiplier = anchor_chunk.multiplier_list
self.network.multiplier = anchor_chunk.multiplier_list + anchor_chunk.multiplier_list
anchor_pred_noise = get_noise_pred(
anchor_chunk.neg_prompt, anchor_chunk.prompt, 1, current_timestep, denoised_latent_chunk
@@ -582,7 +583,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
mask_multiplier_chunks,
unmasked_target_chunks
):
self.network.multiplier = prompt_pair_chunk.multiplier_list
self.network.multiplier = prompt_pair_chunk.multiplier_list + prompt_pair_chunk.multiplier_list
target_latents = get_noise_pred(
prompt_pair_chunk.positive_target,
prompt_pair_chunk.target_class,
@@ -611,6 +612,7 @@ class TrainSliderProcess(BaseSDTrainProcess):
offset_neutral = neutral_latents_chunk
# offsets are already adjusted on a per-batch basis
offset_neutral += offset
offset_neutral = offset_neutral.detach().requires_grad_(False)
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")