Added timestep modifications to lcm scheduler for more evenly spaced timesteps

This commit is contained in:
Jaret Burkett
2023-11-17 23:26:52 -07:00
parent 6280284d8b
commit fbec68681d
4 changed files with 277 additions and 175 deletions

View File

@@ -688,7 +688,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
with self.timer('prepare_noise'):
num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
if self.train_config.noise_scheduler == 'lcm':
if self.train_config.noise_scheduler in ['custom_lcm']:
# we store this value on our custom one
self.sd.noise_scheduler.set_timesteps(
self.sd.noise_scheduler.train_timesteps, device=self.device_torch
)
elif self.train_config.noise_scheduler in ['lcm']:
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
)
@@ -727,12 +732,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
elif self.train_config.content_or_style == 'balanced':
timesteps = torch.randint(
min_noise_steps,
max_noise_steps,
(batch_size,),
device=self.device_torch
)
if min_noise_steps == max_noise_steps:
timesteps = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
else:
timesteps = torch.randint(
min_noise_steps,
max_noise_steps,
(batch_size,),
device=self.device_torch
)
timesteps = timesteps.long()
else:
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")