Added flat snr gamma vs min. Fixes timestep timing

This commit is contained in:
Jaret Burkett
2023-10-29 15:41:55 -06:00
parent 3097865203
commit 436a09430e
4 changed files with 18 additions and 9 deletions

View File

@@ -145,7 +145,11 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
# add snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)