mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added flux_shift as timestep type
This commit is contained in:
@@ -79,7 +79,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.accelerator: Accelerator = get_accelerator()
|
||||
if self.accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
@@ -1066,7 +1066,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd.noise_scheduler.set_train_timesteps(
|
||||
num_train_timesteps,
|
||||
device=self.device_torch,
|
||||
timestep_type=timestep_type
|
||||
timestep_type=timestep_type,
|
||||
latents=latents
|
||||
)
|
||||
else:
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
|
||||
@@ -3,12 +3,27 @@ from typing import Union
|
||||
from torch.distributions import LogNormal
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly map an image sequence length to a shift value ``mu``.

    The result lies on the straight line through the two anchor points
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``. No
    clamping is applied, so sequence lengths outside
    ``[base_seq_len, max_seq_len]`` are extrapolated along the same line.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length of the lower anchor point.
        max_seq_len: Sequence length of the upper anchor point.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated (or extrapolated) shift ``mu``.

    Raises:
        ZeroDivisionError: If ``max_seq_len == base_seq_len`` for plain
            Python numbers, since the slope is then undefined.
    """
    # Slope and intercept of the line through the two anchor points.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
||||
|
||||
|
||||
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.init_noise_sigma = 1.0
|
||||
self.timestep_type = "linear"
|
||||
|
||||
with torch.no_grad():
|
||||
# create weights for timesteps
|
||||
@@ -89,7 +104,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
    """Return ``sample`` unchanged.

    This scheduler applies no input scaling; ``timestep`` is accepted but
    not used.

    Args:
        sample: The model input.
        timestep: Ignored.

    Returns:
        The same ``sample`` object that was passed in.
    """
    return sample
|
||||
|
||||
def set_train_timesteps(self, num_timesteps, device, timestep_type='linear'):
|
||||
def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
|
||||
self.timestep_type = timestep_type
|
||||
if timestep_type == 'linear':
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
|
||||
self.timesteps = timesteps
|
||||
@@ -108,6 +124,42 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
|
||||
return timesteps
|
||||
elif timestep_type == 'flux_shift':
|
||||
# matches inference dynamic shifting
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
|
||||
if latents is None:
|
||||
raise ValueError('latents is None')
|
||||
|
||||
h = latents.shape[2] // 2 # Divide by ph
|
||||
w = latents.shape[3] // 2 # Divide by pw
|
||||
image_seq_len = h * w
|
||||
|
||||
# todo need to know the mu for the shift
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.config.get("base_image_seq_len", 256),
|
||||
self.config.get("max_image_seq_len", 4096),
|
||||
self.config.get("base_shift", 0.5),
|
||||
self.config.get("max_shift", 1.16),
|
||||
)
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
|
||||
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = sigmas
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
return timesteps
|
||||
|
||||
elif timestep_type == 'lognorm_blend':
|
||||
# distribute timesteps to the center/early and blend in linear
|
||||
alpha = 0.75
|
||||
@@ -128,5 +180,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
timesteps, _ = torch.sort(timesteps, descending=True)
|
||||
|
||||
timesteps = timesteps.to(torch.int)
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
return timesteps
|
||||
else:
|
||||
raise ValueError(f"Invalid timestep type: {timestep_type}")
|
||||
|
||||
@@ -1027,9 +1027,9 @@ class StableDiffusion:
|
||||
if self.model_config.use_flux_cfg:
|
||||
pipeline = FluxWithCFGPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
text_encoder=self.text_encoder[0],
|
||||
text_encoder_2=self.text_encoder[1],
|
||||
transformer=unwrap_model(self.unet),
|
||||
text_encoder=unwrap_model(self.text_encoder[0]),
|
||||
text_encoder_2=unwrap_model(self.text_encoder[1]),
|
||||
tokenizer=self.tokenizer[0],
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
scheduler=noise_scheduler,
|
||||
@@ -1039,9 +1039,9 @@ class StableDiffusion:
|
||||
else:
|
||||
pipeline = FluxPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
text_encoder=self.text_encoder[0],
|
||||
text_encoder_2=self.text_encoder[1],
|
||||
transformer=unwrap_model(self.unet),
|
||||
text_encoder=unwrap_model(self.text_encoder[0]),
|
||||
text_encoder_2=unwrap_model(self.text_encoder[1]),
|
||||
tokenizer=self.tokenizer[0],
|
||||
tokenizer_2=self.tokenizer[1],
|
||||
scheduler=noise_scheduler,
|
||||
|
||||
Reference in New Issue
Block a user