Added flux_shift as timestep type

This commit is contained in:
Jaret Burkett
2025-01-27 07:35:00 -07:00
parent 2141c6e06c
commit 34a1c6947a
3 changed files with 64 additions and 9 deletions

View File

@@ -79,7 +79,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.accelerator: Accelerator = get_accelerator()
         if self.accelerator.is_local_main_process:
             transformers.utils.logging.set_verbosity_warning()
-            diffusers.utils.logging.set_verbosity_info()
+            diffusers.utils.logging.set_verbosity_error()
         else:
             transformers.utils.logging.set_verbosity_error()
             diffusers.utils.logging.set_verbosity_error()
@@ -1066,7 +1066,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.sd.noise_scheduler.set_train_timesteps(
                 num_train_timesteps,
                 device=self.device_torch,
-                timestep_type=timestep_type
+                timestep_type=timestep_type,
+                latents=latents
             )
         else:
             self.sd.noise_scheduler.set_timesteps(
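
With this change the trainer forwards the batch latents into the scheduler so the new flux_shift mode can derive the image sequence length. A minimal usage sketch, assuming latents shaped (batch, channels, height, width); the shape below is illustrative, matching a 1024x1024 image through an 8x VAE:

    latents = torch.randn(1, 16, 128, 128)  # hypothetical Flux latents
    scheduler.set_train_timesteps(
        1000,
        device='cuda',
        timestep_type='flux_shift',
        latents=latents,
    )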

View File

@@ -3,12 +3,27 @@ from typing import Union
 from torch.distributions import LogNormal
 from diffusers import FlowMatchEulerDiscreteScheduler
 import torch
+import numpy as np
+
+
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
 class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.init_noise_sigma = 1.0
+        self.timestep_type = "linear"

         with torch.no_grad():
             # create weights for timesteps
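
calculate_shift linearly interpolates the shift parameter mu between base_shift and max_shift as the token count grows from base_seq_len to max_seq_len. A worked example with the defaults above:

    # a 1024x1024 image -> 128x128 latents -> 64x64 tokens after 2x2 patchify
    mu = calculate_shift(64 * 64)  # image_seq_len = 4096
    # m = (1.16 - 0.5) / (4096 - 256) ~= 0.000172
    # mu = 4096 * m + (0.5 - 256 * m) = 1.16, i.e. the max_shift endpoint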
@@ -89,7 +104,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         return sample

-    def set_train_timesteps(self, num_timesteps, device, timestep_type='linear'):
+    def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
+        self.timestep_type = timestep_type
         if timestep_type == 'linear':
             timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
             self.timesteps = timesteps
@@ -108,6 +124,42 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
             self.timesteps = timesteps.to(device=device)
             return timesteps
+        elif timestep_type == 'flux_shift':
+            # matches the dynamic shifting used at inference
+            timesteps = np.linspace(
+                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
+            )
+            sigmas = timesteps / self.config.num_train_timesteps
+
+            if latents is None:
+                raise ValueError('flux_shift requires latents to compute the image sequence length')
+            h = latents.shape[2] // 2  # divide by patch height (ph)
+            w = latents.shape[3] // 2  # divide by patch width (pw)
+            image_seq_len = h * w
+            mu = calculate_shift(
+                image_seq_len,
+                self.config.get("base_image_seq_len", 256),
+                self.config.get("max_image_seq_len", 4096),
+                self.config.get("base_shift", 0.5),
+                self.config.get("max_shift", 1.16),
+            )
+            sigmas = self.time_shift(mu, 1.0, sigmas)
+            sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+            timesteps = sigmas * self.config.num_train_timesteps
+            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+            self.timesteps = timesteps.to(device=device)
+            self.sigmas = sigmas
+            return timesteps
         elif timestep_type == 'lognorm_blend':
             # distribute timesteps to the center/early and blend in linear
             alpha = 0.75
@@ -128,5 +180,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
             timesteps, _ = torch.sort(timesteps, descending=True)
             timesteps = timesteps.to(torch.int)
             self.timesteps = timesteps.to(device=device)
             return timesteps
+        else:
+            raise ValueError(f"Invalid timestep type: {timestep_type}")
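
The time_shift call above is inherited from diffusers' FlowMatchEulerDiscreteScheduler; with sigma=1.0 it applies the same exponential warp Flux uses for dynamic shifting at inference, pushing sigmas toward 1.0 (more steps at high noise) as mu grows. A standalone sketch of the transform, assuming the diffusers formula:

    import numpy as np

    def time_shift(mu: float, sigma: float, t: np.ndarray) -> np.ndarray:
        # larger mu concentrates the schedule at high-noise timesteps
        return np.exp(mu) / (np.exp(mu) + (1 / t - 1) ** sigma)

    # e.g. with mu = 1.16, t = 0.5 maps to exp(1.16) / (exp(1.16) + 1) ~= 0.76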

View File

@@ -1027,9 +1027,9 @@ class StableDiffusion:
         if self.model_config.use_flux_cfg:
             pipeline = FluxWithCFGPipeline(
                 vae=self.vae,
-                transformer=self.unet,
-                text_encoder=self.text_encoder[0],
-                text_encoder_2=self.text_encoder[1],
+                transformer=unwrap_model(self.unet),
+                text_encoder=unwrap_model(self.text_encoder[0]),
+                text_encoder_2=unwrap_model(self.text_encoder[1]),
                 tokenizer=self.tokenizer[0],
                 tokenizer_2=self.tokenizer[1],
                 scheduler=noise_scheduler,
@@ -1039,9 +1039,9 @@ class StableDiffusion:
         else:
             pipeline = FluxPipeline(
                 vae=self.vae,
-                transformer=self.unet,
-                text_encoder=self.text_encoder[0],
-                text_encoder_2=self.text_encoder[1],
+                transformer=unwrap_model(self.unet),
+                text_encoder=unwrap_model(self.text_encoder[0]),
+                text_encoder_2=unwrap_model(self.text_encoder[1]),
                 tokenizer=self.tokenizer[0],
                 tokenizer_2=self.tokenizer[1],
                 scheduler=noise_scheduler,
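
Passing the modules through unwrap_model before building the pipelines matters when training under Accelerate: prepared modules may be wrapped (for example in DistributedDataParallel), while the diffusers pipelines expect the raw nn.Module. A minimal sketch of what such a helper typically does; the actual helper in this repo may differ:

    def unwrap_model(model):
        # peel off DDP-style wrappers to reach the underlying module
        while hasattr(model, 'module'):
            model = model.module
        return model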