Added flat snr gamma vs min. Fixes timestep timing

This commit is contained in:
Jaret Burkett
2023-10-29 15:41:55 -06:00
parent 3097865203
commit 436a09430e
4 changed files with 18 additions and 9 deletions

View File

@@ -559,12 +559,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
# for content / structure, it is best to favor earlier timesteps
# for style, it is best to favor later timesteps
timesteps = torch.rand((batch_size,), device=latents.device)
orig_timesteps = torch.rand((batch_size,), device=latents.device)
if self.train_config.content_or_style == 'style':
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
elif self.train_config.content_or_style == 'content':
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
if self.train_config.content_or_style == 'content':
timesteps = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
elif self.train_config.content_or_style == 'style':
timesteps = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
timesteps = value_map(
timesteps,