mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
More guidance work. Improved LoRA module resolver for unet. Added vega mappings and LoRA training for it. Various other bugfixes and changes
This commit is contained in:
@@ -776,15 +776,19 @@ def apply_snr_weight(
|
||||
):
|
||||
# will get it from noise scheduler if exist or will calculate it if not
|
||||
all_snr = get_all_snr(noise_scheduler, loss.device)
|
||||
step_indices = []
|
||||
for t in timesteps:
|
||||
for i, st in enumerate(noise_scheduler.timesteps):
|
||||
if st == t:
|
||||
step_indices.append(i)
|
||||
break
|
||||
# step_indices = []
|
||||
# for t in timesteps:
|
||||
# for i, st in enumerate(noise_scheduler.timesteps):
|
||||
# if st == t:
|
||||
# step_indices.append(i)
|
||||
# break
|
||||
# this breaks on some schedulers
|
||||
# step_indices = [(noise_scheduler.timesteps == t).nonzero().item() for t in timesteps]
|
||||
snr = torch.stack([all_snr[t] for t in step_indices])
|
||||
|
||||
offset = 0
|
||||
if noise_scheduler.timesteps[0] == 1000:
|
||||
offset = 1
|
||||
snr = torch.stack([all_snr[t - offset] for t in timesteps])
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
if fixed:
|
||||
snr_weight = gamma_over_snr.float().to(loss.device) # directly using gamma over snr
|
||||
|
||||
Reference in New Issue
Block a user