diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index ffc2573a..a68b1f20 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -906,6 +906,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.sd.noise_scheduler.set_timesteps(
                     num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
                 )
+            elif self.train_config.noise_scheduler == 'flowmatch':
+                self.sd.noise_scheduler.set_train_timesteps(
+                    num_train_timesteps, device=self.device_torch
+                )
             else:
                 self.sd.noise_scheduler.set_timesteps(
                     num_train_timesteps, device=self.device_torch
diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py
index 7a513758..a38fcce0 100644
--- a/toolkit/samplers/custom_flowmatch_sampler.py
+++ b/toolkit/samplers/custom_flowmatch_sampler.py
@@ -3,6 +3,7 @@ from typing import Union
 from diffusers import FlowMatchEulerDiscreteScheduler
 import torch
 
+
 class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
 
     def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
@@ -30,10 +31,28 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
         # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
 
-        n_dim = original_samples.ndim
-        sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-        noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
+        # timestep needs to be in [0, 1], we store them in [0, 1000]
+        # noisy_sample = (1 - timestep) * latent + timestep * noise
+        t_01 = (timesteps / 1000).to(original_samples.device)
+        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
+
+        # n_dim = original_samples.ndim
+        # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
+        # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
         return noisy_model_input
 
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
-        return sample
\ No newline at end of file
+        return sample
+
+    def set_train_timesteps(self, num_timesteps, device):
+        # distribute them closer to center. Inference distributes them as a bias toward first
+        # Generate values from 0 to 1
+        t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
+
+        # Scale and reverse the values to go from 1000 to 0
+        timesteps = ((1 - t) * 1000)
+
+        # Sort the timesteps in descending order
+        timesteps, _ = torch.sort(timesteps, descending=True)
+
+        return timesteps
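
Note: the snippet below is a standalone sketch, not part of the patch. It only illustrates the two behaviors the new 'flowmatch' branch relies on: sigmoid-biased timestep sampling, which clusters training timesteps around the middle of the [0, 1000] range rather than biasing them toward the start as inference scheduling does, and the linear interpolation used by add_noise. The function names (sample_train_timesteps, add_flowmatch_noise) and the explicit broadcast reshape are assumptions for illustration, not code from the repository.

# Standalone sketch, illustrative only; function names are assumed.
import torch


def sample_train_timesteps(num_timesteps: int, device="cpu") -> torch.Tensor:
    # sigmoid of a standard normal concentrates values near 0.5, so the
    # resulting timesteps cluster around ~500 instead of being uniform.
    t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
    timesteps = (1 - t) * 1000
    timesteps, _ = torch.sort(timesteps, descending=True)
    return timesteps


def add_flowmatch_noise(latents, noise, timesteps):
    # Linear interpolation: t=0 is the clean latent, t=1000 is pure noise.
    # The reshape is an assumption here so a (B,) timestep vector broadcasts
    # over a (B, C, H, W) latent batch.
    t_01 = (timesteps / 1000).view(-1, *([1] * (latents.ndim - 1))).to(latents.device)
    return (1 - t_01) * latents + t_01 * noise


if __name__ == "__main__":
    ts = sample_train_timesteps(1000)
    print(ts.mean().item())  # roughly 500, biased toward the middle of the schedule
    latents = torch.randn(4, 16, 64, 64)
    noise = torch.randn_like(latents)
    print(add_flowmatch_noise(latents, noise, ts[:4]).shape)  # torch.Size([4, 16, 64, 64])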