From d50f390c7e470761c5734221133bedc7f0febb65 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Thu, 31 Oct 2024 22:57:26 -0700 Subject: [PATCH] rewrite mu shift --- backend/diffusion_engine/flux.py | 12 ++++++++++-- backend/modules/k_prediction.py | 30 ++++++++++++++++++++---------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/backend/diffusion_engine/flux.py b/backend/diffusion_engine/flux.py index 8d4589c2..41b507ab 100644 --- a/backend/diffusion_engine/flux.py +++ b/backend/diffusion_engine/flux.py @@ -33,9 +33,17 @@ class Flux(ForgeDiffusionEngine): vae = VAE(model=huggingface_components['vae']) if 'schnell' in estimated_config.huggingface_repo.lower(): - k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.0, timesteps=10000) + k_predictor = PredictionFlux( + mu=1.0 + ) else: - k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000) + k_predictor = PredictionFlux( + seq_len=4096, + base_seq_len=256, + max_seq_len=4096, + base_shift=0.5, + max_shift=1.15, + ) self.use_distilled_cfg_scale = True unet = UnetPatcher.from_model( diff --git a/backend/modules/k_prediction.py b/backend/modules/k_prediction.py index 18fb2aac..e6c3c150 100644 --- a/backend/modules/k_prediction.py +++ b/backend/modules/k_prediction.py @@ -2,6 +2,9 @@ import math import torch import numpy as np +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.flux.pipeline_flux import calculate_shift + def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): betas = [] @@ -41,10 +44,6 @@ def time_snr_shift(alpha, t): return alpha * t / (1 + (alpha - 1) * t) -def flux_time_shift(mu, sigma, t): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - def rescale_zero_terminal_snr_sigmas(sigmas): alphas_cumprod = 1 / ((sigmas * sigmas) + 1) alphas_bar_sqrt = alphas_cumprod.sqrt() @@ -252,11 +251,22 @@ class PredictionFlow(AbstractPrediction): class PredictionFlux(AbstractPrediction): - def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000): - super().__init__(sigma_data=sigma_data, prediction_type=prediction_type) - self.shift = shift - ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps)) - self.register_buffer('sigmas', ts) + def __init__(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, pseudo_timestep_range=10000, mu=None): + super().__init__(sigma_data=1.0, prediction_type='const') + self.mu = mu + self.pseudo_timestep_range = pseudo_timestep_range + self.apply_mu_transform(seq_len=seq_len, base_seq_len=base_seq_len, max_seq_len=max_seq_len, base_shift=base_shift, max_shift=max_shift, mu=mu) + + def apply_mu_transform(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, mu=None): + # TODO: Add an UI option to let user choose whether to call this in each generation to bind latent size to sigmas + # And some cases may want their own mu values or other parameters + if mu is None: + self.mu = calculate_shift(image_seq_len=seq_len, base_seq_len=base_seq_len, max_seq_len=max_seq_len, base_shift=base_shift, max_shift=max_shift) + else: + self.mu = mu + sigmas = torch.arange(1, self.pseudo_timestep_range + 1, 1) / self.pseudo_timestep_range + sigmas = FlowMatchEulerDiscreteScheduler.time_shift(None, self.mu, 1.0, sigmas) + self.register_buffer('sigmas', sigmas) @property def sigma_min(self): @@ -270,7 +280,7 @@ class PredictionFlux(AbstractPrediction): return sigma def sigma(self, timestep): - return flux_time_shift(self.shift, 1.0, timestep) + return timestep def percent_to_sigma(self, percent): if percent <= 0.0: