More guidance work. Improved LoRA module resolver for unet. Added vega mappings and LoRA training for it. Various other bigfixes and changes

This commit is contained in:
Jaret Burkett
2023-12-15 06:02:10 -07:00
parent e5177833b2
commit 39870411d8
14 changed files with 3501 additions and 106 deletions

View File

@@ -776,15 +776,19 @@ def apply_snr_weight(
):
# will get it from noise scheduler if exist or will calculate it if not
all_snr = get_all_snr(noise_scheduler, loss.device)
step_indices = []
for t in timesteps:
for i, st in enumerate(noise_scheduler.timesteps):
if st == t:
step_indices.append(i)
break
# step_indices = []
# for t in timesteps:
# for i, st in enumerate(noise_scheduler.timesteps):
# if st == t:
# step_indices.append(i)
# break
# this breaks on some schedulers
# step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
snr = torch.stack([all_snr[t] for t in step_indices])
offset = 0
if noise_scheduler.timesteps[0] == 1000:
offset = 1
snr = torch.stack([all_snr[t - offset] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
if fixed:
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr