More fixes for noise schedules and fixed targeted guidance inverted masked prior

Jaret Burkett
2023-11-29 10:13:31 -07:00
parent be815f9c47
commit 7624241032
4 changed files with 119 additions and 10 deletions


@@ -195,22 +195,34 @@ class SDTrainer(BaseSDTrainProcess):
         if batch.unconditional_latents is not None:
             # unconditional latents are the "neutral" images. Add noise here identical to
             # the noise added to the conditional latents, at the same timesteps
-            unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
-                batch.unconditional_latents, noise, timesteps
-            )
+            # unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(
+            #     batch.unconditional_latents, noise, timesteps
+            # )
+            unconditional_noisy_latents = self.sd.add_noise(batch.unconditional_latents, noise, timesteps)
             # calculate the differential between our conditional (target image) and our unconditional (neutral image)
             target_differential_noise = unconditional_noisy_latents - conditional_noisy_latents
             target_differential_noise = target_differential_noise.detach()
+            # Calculate the mean along dim=1 (channels), keeping dimensions
+            mean_chan = torch.abs(torch.mean(target_differential_noise, dim=1, keepdim=True))
+            # Create a mask with 0s where values are between 0.0 and 0.01, otherwise 1s
+            mask = torch.where((mean_chan >= 0.0) & (mean_chan <= 0.01), 0.0, 1.0)
+            # Expand the mask along dim=1 to match the shape of target_differential_noise
+            mask = mask.expand_as(target_differential_noise)
+            # this mask is now 1 for our target differential and 0 for everything else
             # add the target differential to the target latents as if it were noise with the scheduler, scaled to
             # the current timestep. Scaling the noise here is important as it scales our guidance to the current
             # timestep. This is the key to making the guidance work.
-            guidance_latents = self.sd.noise_scheduler.add_noise(
-                conditional_noisy_latents,
-                target_differential_noise,
-                timesteps
-            )
+            # guidance_latents = self.sd.noise_scheduler.add_noise(
+            #     conditional_noisy_latents,
+            #     target_differential_noise,
+            #     timesteps
+            # )
+            guidance_latents = self.sd.add_noise(conditional_noisy_latents, target_differential_noise, timesteps)
             # Disable the LoRA network so we can predict parent network knowledge without it
             self.network.is_active = False
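
The masking logic in this hunk is compact enough to test in isolation. The following sketch reproduces it as a standalone function, assuming standard latent tensors of shape (batch, channels, height, width); the name differential_mask and the toy shapes are illustrative, not from the commit, and only the 0.01 threshold comes from the diff. Because the channel mean is passed through torch.abs, the >= 0.0 bound in the commit is always satisfied, so a plain threshold comparison is equivalent:

    import torch

    def differential_mask(conditional_noisy, unconditional_noisy, threshold=0.01):
        # differential between the noised unconditional (neutral) and conditional (target) latents
        diff = (unconditional_noisy - conditional_noisy).detach()
        # per-location magnitude, averaged across the channel dim
        mean_chan = torch.mean(diff, dim=1, keepdim=True).abs()
        # 1 where the two images meaningfully differ, 0 where they are near-identical
        mask = (mean_chan > threshold).float()
        # broadcast the single-channel mask back to the full latent shape
        return diff, mask.expand_as(diff)

    # toy latents: batch=2, 4 channels, 8x8 latent grid
    cond = torch.randn(2, 4, 8, 8)
    uncond = cond.clone()
    uncond[:, :, :4, :] += 0.5  # make only the top half differ
    diff, mask = differential_mask(cond, uncond)
    print(mask[0, 0])  # ~1s in the top half, 0s in the bottom half

When the commit then re-injects target_differential_noise via self.sd.add_noise(...), a DDPM-style scheduler blends it as roughly sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * differential, so the guidance strength tracks the noise level of the sampled timestep, which is what the comment above means by scaling the guidance to the current timestep.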
@@ -254,7 +266,22 @@ class SDTrainer(BaseSDTrainProcess):
             reduction="none"
         )
+        # multiply by our mask
+        loss = loss * mask
         loss = loss.mean([1, 2, 3])
+        # calculate the inverse loss to match the baseline prediction outside the mask
+        unmasked_prior_loss = torch.nn.functional.mse_loss(
+            baseline_prediction.float(),
+            prediction.float(),
+            reduction="none"
+        )
+        # multiply by the inverted mask
+        unmasked_prior_loss = unmasked_prior_loss * (1.0 - mask)
+        # add the unmasked prior loss to the masked loss
+        unmasked_prior_loss = unmasked_prior_loss.mean([1, 2, 3])
+        loss = loss + unmasked_prior_loss
         loss = self.apply_snr(loss, timesteps)
         loss = loss.mean()
         loss.backward()
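
A minimal sketch of the combined loss in this hunk, assuming the names from the diff: prediction is the network output with the LoRA active, baseline_prediction is the output with it disabled, and target stands in for whatever the trainer regresses against. The helper name masked_guidance_loss and the toy tensors are illustrative, and the per-timestep apply_snr weighting is elided:

    import torch
    import torch.nn.functional as F

    def masked_guidance_loss(prediction, target, baseline_prediction, mask):
        # masked term: match the guidance target only where the images differ
        masked = F.mse_loss(prediction.float(), target.float(), reduction="none") * mask
        # inverted-mask prior term: everywhere else, stay close to the frozen
        # (LoRA-disabled) baseline prediction so unrelated regions are preserved
        prior = F.mse_loss(prediction.float(), baseline_prediction.float(), reduction="none") * (1.0 - mask)
        # reduce each term per-sample, then combine
        return masked.mean([1, 2, 3]) + prior.mean([1, 2, 3])

    # toy shapes: batch=2, 4 latent channels, 8x8 latent grid
    prediction = torch.randn(2, 4, 8, 8, requires_grad=True)
    target = torch.randn(2, 4, 8, 8)
    baseline_prediction = torch.randn(2, 4, 8, 8)
    mask = (torch.rand(2, 1, 8, 8) > 0.5).float().expand_as(target)
    loss = masked_guidance_loss(prediction, target, baseline_prediction, mask).mean()
    loss.backward()

The inverted-mask term acts as a prior-preservation loss: wherever the conditional and unconditional images agree, the trained network is pulled toward its own LoRA-disabled prediction instead of the guidance target, which is the "inverted masked prior" fix named in the commit message.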