mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Reworked timestep distribution on flowmatch sampler when training.
This commit is contained in:
@@ -906,6 +906,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
|
||||
)
|
||||
elif self.train_config.noise_scheduler == 'flowmatch':
|
||||
self.sd.noise_scheduler.set_train_timesteps(
|
||||
num_train_timesteps, device=self.device_torch
|
||||
)
|
||||
else:
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
num_train_timesteps, device=self.device_torch
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Union
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||
import torch
|
||||
|
||||
|
||||
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
|
||||
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
|
||||
@@ -30,10 +31,28 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
||||
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
n_dim = original_samples.ndim
|
||||
sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
|
||||
noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
|
||||
# timestep needs to be in [0, 1], we store them in [0, 1000]
|
||||
# noisy_sample = (1 - timestep) * latent + timestep * noise
|
||||
t_01 = (timesteps / 1000).to(original_samples.device)
|
||||
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
|
||||
|
||||
# n_dim = original_samples.ndim
|
||||
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
|
||||
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
|
||||
return noisy_model_input
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
return sample
|
||||
return sample
|
||||
|
||||
def set_train_timesteps(self, num_timesteps, device):
|
||||
# distribute them closer to center. Inference distributes them as a bias toward first
|
||||
# Generate values from 0 to 1
|
||||
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
|
||||
|
||||
# Scale and reverse the values to go from 1000 to 0
|
||||
timesteps = ((1 - t) * 1000)
|
||||
|
||||
# Sort the timesteps in descending order
|
||||
timesteps, _ = torch.sort(timesteps, descending=True)
|
||||
|
||||
return timesteps
|
||||
|
||||
Reference in New Issue
Block a user