mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-20 12:23:57 +00:00
59 lines
2.3 KiB
Python
59 lines
2.3 KiB
Python
from typing import Union
|
|
|
|
from diffusers import FlowMatchEulerDiscreteScheduler
|
|
import torch
|
|
|
|
|
|
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    """Flow-matching Euler scheduler with extra helpers used for training.

    Adds sigma lookup by timestep, flow-matching noise mixing, a no-op
    ``scale_model_input`` and random sampling of training timesteps on top
    of the base scheduler.
    """

    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
        """Look up the sigma for each timestep, shaped for broadcasting.

        Args:
            timesteps: Iterable tensor of timesteps; each value must match an
                entry of ``self.timesteps`` exactly.
            n_dim: Rank the result is padded to with trailing singleton axes.
            dtype: Dtype the sigma table is cast to.
            device: Device the lookup is performed on.

        Returns:
            Tensor with one sigma per timestep and at least ``n_dim`` axes.

        Raises:
            An error from ``Tensor.item()`` when a timestep matches zero or
            more than one schedule entry.
        """
        sigmas = self.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.timesteps.to(device)
        timesteps = timesteps.to(device)
        # Exact-equality lookup: ``.item()`` insists on a single match.
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Pad with trailing axes so the result broadcasts against a sample batch.
        while sigma.ndim < n_dim:
            sigma = sigma.unsqueeze(-1)

        return sigma

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Blend samples with noise according to flow matching.

        Computes ``(1 - t) * original_samples + t * noise`` where
        ``t = timesteps / 1000`` (timesteps are stored in [0, 1000], the
        interpolation weight must be in [0, 1]).

        ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
        Unlike that reference, the weight is derived directly from the
        timestep instead of being looked up through ``get_sigmas``.

        Args:
            original_samples: Clean samples.
            noise: Noise tensor, broadcastable with ``original_samples``.
            timesteps: Tensor of timesteps. A 1-D tensor whose length equals
                the leading (batch) dimension of ``original_samples`` is
                applied per sample.

        Returns:
            The noised samples.
        """
        t_01 = (timesteps / 1000).to(original_samples.device)

        # A per-sample timestep vector of shape (B,) would otherwise broadcast
        # against the LAST axis of the samples: an error in general, and a
        # silently wrong mix when that axis happens to have size B. Give it
        # trailing singleton axes so it lines up with the batch axis instead.
        # 0-dim and length-1 inputs already broadcast correctly and are left
        # untouched so their dtype-promotion behaviour does not change.
        if (
            t_01.ndim == 1
            and original_samples.ndim > 1
            and t_01.shape[0] == original_samples.shape[0]
        ):
            t_01 = t_01.reshape(-1, *([1] * (original_samples.ndim - 1)))

        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
        return noisy_model_input

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        """Return ``sample`` unchanged; ``timestep`` is ignored."""
        return sample

    def set_train_timesteps(self, num_timesteps, device):
        """Sample a descending set of random training timesteps.

        Draws ``num_timesteps`` standard-normal values, squashes them with a
        sigmoid so they cluster around the centre of the range, maps them to
        ``(1 - t) * 1000`` and sorts them from high to low.

        The result is only returned; this method does not assign it to
        ``self.timesteps``.

        Args:
            num_timesteps: Number of timesteps to draw.
            device: Device on which the random values are generated.

        Returns:
            1-D tensor of ``num_timesteps`` values in descending order.
        """
        # distribute them closer to center. Inference distributes them as a bias toward first
        # Generate values from 0 to 1
        t = torch.sigmoid(torch.randn((num_timesteps,), device=device))

        # Scale and reverse the values to go from 1000 to 0
        timesteps = (1 - t) * 1000

        # Sort the timesteps in descending order
        timesteps, _ = torch.sort(timesteps, descending=True)

        return timesteps
|