From 04424fe2d679aae7bd6f2f25dcd61db87e067f2a Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 24 Sep 2024 06:53:59 -0600 Subject: [PATCH] Added config setting to set the timestep type --- jobs/process/BaseSDTrainProcess.py | 7 ++++++- toolkit/config_modules.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 8c480147..5fb0d493 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -913,10 +913,15 @@ class BaseSDTrainProcess(BaseTrainProcess): num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps ) elif self.train_config.noise_scheduler == 'flowmatch': + linear_timesteps = any([ + self.train_config.linear_timesteps, + self.train_config.linear_timesteps2, + self.train_config.timestep_type == 'linear', + ]) self.sd.noise_scheduler.set_train_timesteps( num_train_timesteps, device=self.device_torch, - linear=self.train_config.linear_timesteps or self.train_config.linear_timesteps2 + linear=linear_timesteps ) else: self.sd.noise_scheduler.set_timesteps( diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a1b0dd42..61b6a35c 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -373,6 +373,7 @@ class TrainConfig: # adds an additional loss to the network to encourage it output a normalized standard deviation self.target_norm_std = kwargs.get('target_norm_std', None) self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0) + self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear self.linear_timesteps = kwargs.get('linear_timesteps', False) self.linear_timesteps2 = kwargs.get('linear_timesteps2', False) self.disable_sampling = kwargs.get('disable_sampling', False)