Added config setting to set the timestep type

This commit is contained in:
Jaret Burkett
2024-09-24 06:53:59 -06:00
parent 40a8ff5731
commit 04424fe2d6
2 changed files with 7 additions and 1 deletions

View File

@@ -913,10 +913,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
)
elif self.train_config.noise_scheduler == 'flowmatch':
linear_timesteps = any([
self.train_config.linear_timesteps,
self.train_config.linear_timesteps2,
self.train_config.timestep_type == 'linear',
])
self.sd.noise_scheduler.set_train_timesteps(
num_train_timesteps,
device=self.device_torch,
linear=self.train_config.linear_timesteps or self.train_config.linear_timesteps2
linear=linear_timesteps
)
else:
self.sd.noise_scheduler.set_timesteps(

View File

@@ -373,6 +373,7 @@ class TrainConfig:
# adds an additional loss to the network to encourage it output a normalized standard deviation
self.target_norm_std = kwargs.get('target_norm_std', None)
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear
self.linear_timesteps = kwargs.get('linear_timesteps', False)
self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
self.disable_sampling = kwargs.get('disable_sampling', False)