Reworked timestep distribution on flowmatch sampler when training.

This commit is contained in:
Jaret Burkett
2024-08-08 06:01:45 -06:00
parent acafe9984f
commit e69a520616
2 changed files with 27 additions and 4 deletions

View File

@@ -906,6 +906,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
)
elif self.train_config.noise_scheduler == 'flowmatch':
self.sd.noise_scheduler.set_train_timesteps(
num_train_timesteps, device=self.device_torch
)
else:
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch

View File

@@ -3,6 +3,7 @@ from typing import Union
from diffusers import FlowMatchEulerDiscreteScheduler
import torch
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
@@ -30,10 +31,28 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
n_dim = original_samples.ndim
sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
# timesteps need to be in [0, 1]; we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
# n_dim = original_samples.ndim
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """Return ``sample`` unchanged.

    This override applies no scaling; ``timestep`` is accepted only to keep
    the scheduler call signature and is not used.

    Args:
        sample: The model input.
        timestep: Ignored.

    Returns:
        The very same ``sample`` object that was passed in.
    """
    # A second, unreachable `return sample` used to follow this one.
    return sample
def set_train_timesteps(self, num_timesteps, device):
    """Build and install a randomized training timestep schedule.

    Draws ``num_timesteps`` standard-normal samples, squashes them through a
    sigmoid so they land in (0, 1) clustered around 0.5, maps them onto the
    0..1000 timestep scale and sorts them from high to low.

    The schedule is stored on ``self.timesteps`` as well as returned: a
    caller that invokes this method only for its side effect and ignores the
    return value would otherwise leave the scheduler's timesteps untouched.

    Args:
        num_timesteps: Number of timesteps to generate.
        device: Torch device on which the schedule tensor is created.

    Returns:
        A 1-D tensor of ``num_timesteps`` values sorted in descending order.
    """
    # Sigmoid of Gaussian noise concentrates the samples toward the middle of
    # the range instead of spreading them evenly.
    t = torch.sigmoid(torch.randn((num_timesteps,), device=device))

    # Scale to the 0..1000 range used for stored timesteps and flip the
    # direction (t near 0 -> timestep near 1000).
    timesteps = (1 - t) * 1000

    # Highest timestep first.
    timesteps, _ = torch.sort(timesteps, descending=True)

    # Persist the schedule so the scheduler actually uses it.
    self.timesteps = timesteps

    return timesteps