mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added flat snr gamma vs min. Fixes timestep timing
This commit is contained in:
@@ -559,12 +559,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# for content / structure, it is best to favor earlier timesteps
|
||||
# for style, it is best to favor later timesteps
|
||||
|
||||
timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
orig_timesteps = torch.rand((batch_size,), device=latents.device)
|
||||
|
||||
if self.train_config.content_or_style == 'style':
|
||||
timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'content':
|
||||
timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
if self.train_config.content_or_style == 'content':
|
||||
timesteps = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
elif self.train_config.content_or_style == 'style':
|
||||
timesteps = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
timesteps = value_map(
|
||||
timesteps,
|
||||
|
||||
Reference in New Issue
Block a user