import math
from typing import Union

import numpy as np
import torch
from torch.distributions import LogNormal

from diffusers import FlowMatchEulerDiscreteScheduler


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
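
# Quick sanity check of the interpolation (illustrative, not from the original
# file; assumes the default Flux-style config values above):
#   calculate_shift(4096) -> 1.16   (hits max_shift at max_seq_len)
#   calculate_shift(1024) -> ~0.632 (e.g. a 512x512 image: 64x64 latent, 2x2 packed)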


class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_noise_sigma = 1.0
        self.timestep_type = "linear"

        with torch.no_grad():
            # create weights for timesteps
            num_timesteps = 1000

            # Bell-Shaped Mean-Normalized Timestep Weighting
            # bsmntw? need a better name
            x = torch.arange(num_timesteps, dtype=torch.float32)
            y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)

            # Shift minimum to 0
            y_shifted = y - y.min()

            # Scale to make mean 1
            bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())

            # only do half bell
            hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())

            # flatten second half to max
            hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()

            # Create linear timesteps from 1000 to 0
            timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')

            self.linear_timesteps = timesteps
            self.linear_timesteps_weights = bsmntw_weighing
            self.linear_timesteps_weights2 = hbsmntw_weighing
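
    # Sanity check (illustrative; `scheduler` is a hypothetical instance of this
    # class): the full-bell weights are mean-normalized, so weighting per-timestep
    # losses by them leaves the average loss scale unchanged. The half-bell
    # variant is no longer mean-1 once its second half is flattened.
    #   assert abs(scheduler.linear_timesteps_weights.mean().item() - 1.0) < 1e-4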

    def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
        # Get the indices of the timesteps
        step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]

        # Get the weights for the timesteps
        if v2:
            weights = self.linear_timesteps_weights2[step_indices].flatten()
        else:
            weights = self.linear_timesteps_weights[step_indices].flatten()

        return weights
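
    # Typical use in a training loop (a sketch; `scheduler`, `loss` of shape
    # (batch, ...) and `timesteps` are assumed to come from the caller):
    #   weights = scheduler.get_weights_for_timesteps(timesteps)
    #   loss = (loss.mean(dim=list(range(1, loss.ndim))) * weights).mean()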

    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
        sigmas = self.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self.timesteps.to(device)
        timesteps = timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)

        return sigma
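
    # Shape note (illustrative): with n_dim=4 the returned sigma has shape
    # (batch, 1, 1, 1), so it broadcasts against (batch, channels, h, w) latents.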

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
        ## Add noise according to flow matching.
        ## zt = (1 - texp) * x + texp * z1

        # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
        # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

        # timesteps need to be in [0, 1]; we store them in [0, 1000]
        # noisy_sample = (1 - timestep) * latent + timestep * noise
        t_01 = (timesteps / 1000).to(original_samples.device)
        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise

        # n_dim = original_samples.ndim
        # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
        # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise

        return noisy_model_input
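
    # Endpoint check (illustrative): at timestep 1000, t_01 == 1 and the output
    # is pure noise; at timestep 0, t_01 == 0 and the clean latent is returned,
    # matching zt = (1 - t) * x + t * z1 above.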

    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
        return sample
    def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
        self.timestep_type = timestep_type
        if timestep_type == 'linear':
            timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
            self.timesteps = timesteps
            return timesteps
        elif timestep_type == 'sigmoid':
            # Distribute timesteps toward the center; inference instead biases
            # them toward the start.
            # Generate values from 0 to 1
            t = torch.sigmoid(torch.randn((num_timesteps,), device=device))

            # Scale and reverse the values to go from 1000 to 0
            timesteps = ((1 - t) * 1000)

            # Sort the timesteps in descending order
            timesteps, _ = torch.sort(timesteps, descending=True)

            self.timesteps = timesteps.to(device=device)

            return timesteps
        elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift':
            # matches inference dynamic shifting
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
            )

            sigmas = timesteps / self.config.num_train_timesteps

            if latents is None:
                raise ValueError('latents is None')

            h = latents.shape[2] // 2  # divide by patch height (ph)
            w = latents.shape[3] // 2  # divide by patch width (pw)
            image_seq_len = h * w

            mu = calculate_shift(
                image_seq_len,
                self.config.get("base_image_seq_len", 256),
                self.config.get("max_image_seq_len", 4096),
                self.config.get("base_shift", 0.5),
                self.config.get("max_shift", 1.16),
            )
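
            # Dynamic shifting (for reference; formula per the diffusers base
            # class's time_shift with exponent 1.0):
            #   sigma' = exp(mu) / (exp(mu) + (1 / sigma - 1))
            # Larger images (larger mu) push more steps toward high noise.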
            sigmas = self.time_shift(mu, 1.0, sigmas)

            sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
            timesteps = sigmas * self.config.num_train_timesteps

            # append the terminal zero sigma
            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

            self.timesteps = timesteps.to(device=device)
            self.sigmas = sigmas
            return timesteps
        elif timestep_type == 'lognorm_blend':
            # Distribute most timesteps toward the center/early range and blend
            # in a linear tail.
            alpha = 0.75

            lognormal = LogNormal(loc=0, scale=0.333)

            # Sample from the distribution
            t1 = lognormal.sample((int(num_timesteps * alpha),)).to(device)

            # Scale and reverse the values to go from 1000 to 0
            t1 = ((1 - t1 / t1.max()) * 1000)

            # Blend in the remaining (1 - alpha) fraction as linear timesteps
            t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device)
            timesteps = torch.cat((t1, t2))

            # Sort the timesteps in descending order
            timesteps, _ = torch.sort(timesteps, descending=True)

            timesteps = timesteps.to(torch.int)
            self.timesteps = timesteps.to(device=device)
            return timesteps
        else:
            raise ValueError(f"Invalid timestep type: {timestep_type}")
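

# Minimal usage sketch (illustrative; shapes and argument values are
# assumptions, not taken from this repo's configs):
#   scheduler = CustomFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
#   latents = torch.randn(1, 16, 64, 64)  # e.g. an unpacked 512x512 latent
#   ts = scheduler.set_train_timesteps(1000, device='cpu', timestep_type='flux_shift', latents=latents)
#   noisy = scheduler.add_noise(latents, torch.randn_like(latents), ts[:1])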