From a6aa4b2c7dffa88e8a46898385785bf79a232c4d Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 11 Aug 2024 13:06:08 -0600
Subject: [PATCH] Added ability to set timesteps to linear for flowmatching
 schedule

---
 jobs/process/BaseSDTrainProcess.py           |  4 +++-
 toolkit/config_modules.py                    |  1 +
 toolkit/samplers/custom_flowmatch_sampler.py | 25 ++++++++++++--------
 3 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index a68b1f20..d83c6081 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -908,7 +908,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             )
         elif self.train_config.noise_scheduler == 'flowmatch':
             self.sd.noise_scheduler.set_train_timesteps(
-                num_train_timesteps, device=self.device_torch
+                num_train_timesteps,
+                device=self.device_torch,
+                linear=self.train_config.linear_timesteps
             )
         else:
             self.sd.noise_scheduler.set_timesteps(
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 491c79b5..a8041c5a 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -356,6 +356,7 @@ class TrainConfig:
         # adds an additional loss to the network to encourage it output a normalized standard deviation
         self.target_norm_std = kwargs.get('target_norm_std', None)
         self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
+        self.linear_timesteps = kwargs.get('linear_timesteps', False)
 
 
 class ModelConfig:
diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py
index 0a1d7f45..5782f45f 100644
--- a/toolkit/samplers/custom_flowmatch_sampler.py
+++ b/toolkit/samplers/custom_flowmatch_sampler.py
@@ -44,17 +44,22 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         return sample
 
-    def set_train_timesteps(self, num_timesteps, device):
-        # distribute them closer to center. Inference distributes them as a bias toward first
-        # Generate values from 0 to 1
-        t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
+    def set_train_timesteps(self, num_timesteps, device, linear=False):
+        if linear:
+            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
+            self.timesteps = timesteps
+            return timesteps
+        else:
+            # distribute them closer to center. Inference distributes them as a bias toward first
+            # Generate values from 0 to 1
+            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
 
-        # Scale and reverse the values to go from 1000 to 0
-        timesteps = ((1 - t) * 1000)
+            # Scale and reverse the values to go from 1000 to 0
+            timesteps = ((1 - t) * 1000)
 
-        # Sort the timesteps in descending order
-        timesteps, _ = torch.sort(timesteps, descending=True)
+            # Sort the timesteps in descending order
+            timesteps, _ = torch.sort(timesteps, descending=True)
 
-        self.timesteps = timesteps.to(device=device)
+            self.timesteps = timesteps.to(device=device)
 
-        return timesteps
+            return timesteps