diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 058ba375..d946d445 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -79,7 +79,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.accelerator: Accelerator = get_accelerator()
         if self.accelerator.is_local_main_process:
             transformers.utils.logging.set_verbosity_warning()
-            diffusers.utils.logging.set_verbosity_info()
+            diffusers.utils.logging.set_verbosity_error()
         else:
             transformers.utils.logging.set_verbosity_error()
             diffusers.utils.logging.set_verbosity_error()
@@ -1066,7 +1066,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.sd.noise_scheduler.set_train_timesteps(
                     num_train_timesteps,
                     device=self.device_torch,
-                    timestep_type=timestep_type
+                    timestep_type=timestep_type,
+                    latents=latents
                 )
             else:
                 self.sd.noise_scheduler.set_timesteps(
diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py
index 6c5b90df..a4c53db1 100644
--- a/toolkit/samplers/custom_flowmatch_sampler.py
+++ b/toolkit/samplers/custom_flowmatch_sampler.py
@@ -3,12 +3,27 @@ from typing import Union
 from torch.distributions import LogNormal
 from diffusers import FlowMatchEulerDiscreteScheduler
 import torch
+import numpy as np
+
+
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
 
 
 class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.init_noise_sigma = 1.0
+        self.timestep_type = "linear"
 
         with torch.no_grad():
             # create weights for timesteps
@@ -89,7 +104,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         return sample
 
-    def set_train_timesteps(self, num_timesteps, device, timestep_type='linear'):
+    def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None):
+        self.timestep_type = timestep_type
         if timestep_type == 'linear':
             timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
             self.timesteps = timesteps
@@ -108,6 +124,41 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
             self.timesteps = timesteps.to(device=device)
 
             return timesteps
+        elif timestep_type == 'flux_shift':
+            # matches inference dynamic shifting
+            timesteps = np.linspace(
+                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
+            )
+
+            sigmas = timesteps / self.config.num_train_timesteps
+
+            if latents is None:
+                raise ValueError("latents are required for the 'flux_shift' timestep type")
+
+            h = latents.shape[2] // 2  # divide by patch height ph
+            w = latents.shape[3] // 2  # divide by patch width pw
+            image_seq_len = h * w
+
+            # compute mu for the dynamic shift from the packed sequence length
+            mu = calculate_shift(
+                image_seq_len,
+                self.config.get("base_image_seq_len", 256),
+                self.config.get("max_image_seq_len", 4096),
+                self.config.get("base_shift", 0.5),
+                self.config.get("max_shift", 1.16),
+            )
+            sigmas = self.time_shift(mu, 1.0, sigmas)
+
+            sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+            timesteps = sigmas * self.config.num_train_timesteps
+
+            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+            self.timesteps = timesteps.to(device=device)
+            self.sigmas = sigmas
+
+            return timesteps
+
         elif timestep_type == 'lognorm_blend':
             # distribute timesteps to the center/early and blend in linear
             alpha = 0.75
@@ -128,5 +179,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
 
             timesteps, _ = torch.sort(timesteps, descending=True)
             timesteps = timesteps.to(torch.int)
+            self.timesteps = timesteps.to(device=device)
+            return timesteps
         else:
             raise ValueError(f"Invalid timestep type: {timestep_type}")
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 66469a9f..d54fc40e 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1027,9 +1027,9 @@ class StableDiffusion:
         if self.model_config.use_flux_cfg:
             pipeline = FluxWithCFGPipeline(
                 vae=self.vae,
-                transformer=self.unet,
-                text_encoder=self.text_encoder[0],
-                text_encoder_2=self.text_encoder[1],
+                transformer=unwrap_model(self.unet),
+                text_encoder=unwrap_model(self.text_encoder[0]),
+                text_encoder_2=unwrap_model(self.text_encoder[1]),
                 tokenizer=self.tokenizer[0],
                 tokenizer_2=self.tokenizer[1],
                 scheduler=noise_scheduler,
@@ -1039,9 +1039,9 @@ class StableDiffusion:
         else:
             pipeline = FluxPipeline(
                 vae=self.vae,
-                transformer=self.unet,
-                text_encoder=self.text_encoder[0],
-                text_encoder_2=self.text_encoder[1],
+                transformer=unwrap_model(self.unet),
+                text_encoder=unwrap_model(self.text_encoder[0]),
+                text_encoder_2=unwrap_model(self.text_encoder[1]),
                 tokenizer=self.tokenizer[0],
                 tokenizer_2=self.tokenizer[1],
                 scheduler=noise_scheduler,
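
For reviewers: below is a minimal standalone sketch of the sigma math the new 'flux_shift' branch performs during training. The packed token count of the latent selects a shift parameter mu, and mu exponentially warps the sigma schedule toward high noise for large images, matching Flux's inference-time dynamic shifting. The exp-based time_shift formula is assumed to mirror FlowMatchEulerDiscreteScheduler.time_shift in diffusers, and the 1024x1024 image size is illustrative only.

import numpy as np

def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.16):
    # Linearly interpolate mu with the packed token count, as in the
    # calculate_shift helper added by this diff.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    return image_seq_len * m + (base_shift - m * base_seq_len)

def time_shift(mu, sigma, t):
    # Assumed form of FlowMatchEulerDiscreteScheduler.time_shift: larger mu
    # pushes every sigma toward 1.0, i.e. more training at high noise.
    return np.exp(mu) / (np.exp(mu) + (1 / t - 1) ** sigma)

# Example: 1024x1024 Flux image -> 128x128 VAE latent -> 64x64 packed 2x2 patches.
h, w = 128 // 2, 128 // 2
mu = calculate_shift(h * w)             # 4096 tokens -> mu == max_shift == 1.16
sigmas = np.linspace(1.0, 1e-3, 10)     # toy linear schedule
print(mu, time_shift(mu, 1.0, sigmas))  # shifted schedule used for training

At the other extreme, a 256x256 image packs to 256 tokens, so mu falls back to base_shift (0.5) and small images train on an almost unshifted schedule.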
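On the toolkit/stable_diffusion_model.py hunks: passing the transformer and text encoders through unwrap_model matters when those modules were prepared by accelerate, since sampling would otherwise hand DDP-wrapped modules to the diffusers pipelines. As an assumption about what the toolkit's helper does (the real unwrap_model may differ), such a helper typically looks like:

import torch.nn as nn

def unwrap_model(model: nn.Module) -> nn.Module:
    # Hypothetical sketch: peel off wrappers such as
    # torch.nn.parallel.DistributedDataParallel, which expose the
    # original module via a `.module` attribute.
    while hasattr(model, "module") and isinstance(model.module, nn.Module):
        model = model.module
    return model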