More work on custom adapter

This commit is contained in:
Jaret Burkett
2024-01-16 17:41:26 -07:00
parent eebd3c8212
commit 655533d4c7
2 changed files with 3 additions and 2 deletions

View File

@@ -788,7 +788,7 @@ def apply_snr_weight(
offset = 0
if noise_scheduler.timesteps[0] == 1000:
offset = 1
snr = torch.stack([all_snr[t - offset] for t in timesteps])
snr = torch.stack([all_snr[(t - offset).int()] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
if fixed:
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr