mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added timestep modifications to lcm scheduler for more evenly spaced timesteps
This commit is contained in:
@@ -688,7 +688,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
with self.timer('prepare_noise'):
|
||||
num_train_timesteps = self.sd.noise_scheduler.config['num_train_timesteps']
|
||||
|
||||
if self.train_config.noise_scheduler == 'lcm':
|
||||
if self.train_config.noise_scheduler in ['custom_lcm']:
|
||||
# we store this value on our custom one
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
self.sd.noise_scheduler.train_timesteps, device=self.device_torch
|
||||
)
|
||||
elif self.train_config.noise_scheduler in ['lcm']:
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
|
||||
)
|
||||
@@ -727,12 +732,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
)
|
||||
|
||||
elif self.train_config.content_or_style == 'balanced':
|
||||
timesteps = torch.randint(
|
||||
min_noise_steps,
|
||||
max_noise_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
if min_noise_steps == max_noise_steps:
|
||||
timesteps = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
|
||||
else:
|
||||
timesteps = torch.randint(
|
||||
min_noise_steps,
|
||||
max_noise_steps,
|
||||
(batch_size,),
|
||||
device=self.device_torch
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
else:
|
||||
raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
|
||||
|
||||
Reference in New Issue
Block a user