Added flat snr gamma vs min. Fixes timestep timing

This commit is contained in:
Jaret Burkett
2023-10-29 15:41:55 -06:00
parent 3097865203
commit 436a09430e
4 changed files with 18 additions and 9 deletions

View File

@@ -688,13 +688,17 @@ def apply_snr_weight(
loss,
timesteps,
noise_scheduler: Union['DDPMScheduler'],
gamma
gamma,
fixed=False,
):
# will get it form noise scheduler if exist or will calculate it if not
# will get it from noise scheduler if exist or will calculate it if not
all_snr = get_all_snr(noise_scheduler, loss.device)
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
if fixed:
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
else:
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
loss = loss * snr_weight
return loss