rewrite mu shift

This commit is contained in:
layerdiffusion
2024-10-31 22:57:26 -07:00
parent d1b996f3c3
commit d50f390c7e
2 changed files with 30 additions and 12 deletions

View File

@@ -33,9 +33,17 @@ class Flux(ForgeDiffusionEngine):
vae = VAE(model=huggingface_components['vae'])
if 'schnell' in estimated_config.huggingface_repo.lower():
k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.0, timesteps=10000)
k_predictor = PredictionFlux(
mu=1.0
)
else:
k_predictor = PredictionFlux(sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000)
k_predictor = PredictionFlux(
seq_len=4096,
base_seq_len=256,
max_seq_len=4096,
base_shift=0.5,
max_shift=1.15,
)
self.use_distilled_cfg_scale = True
unet = UnetPatcher.from_model(

View File

@@ -2,6 +2,9 @@ import math
import torch
import numpy as np
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.flux.pipeline_flux import calculate_shift
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
betas = []
@@ -41,10 +44,6 @@ def time_snr_shift(alpha, t):
return alpha * t / (1 + (alpha - 1) * t)
def flux_time_shift(mu, sigma, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
@@ -252,11 +251,22 @@ class PredictionFlow(AbstractPrediction):
class PredictionFlux(AbstractPrediction):
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.15, timesteps=10000):
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
self.shift = shift
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
self.register_buffer('sigmas', ts)
def __init__(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, pseudo_timestep_range=10000, mu=None):
super().__init__(sigma_data=1.0, prediction_type='const')
self.mu = mu
self.pseudo_timestep_range = pseudo_timestep_range
self.apply_mu_transform(seq_len=seq_len, base_seq_len=base_seq_len, max_seq_len=max_seq_len, base_shift=base_shift, max_shift=max_shift, mu=mu)
def apply_mu_transform(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, mu=None):
# TODO: Add an UI option to let user choose whether to call this in each generation to bind latent size to sigmas
# And some cases may want their own mu values or other parameters
if mu is None:
self.mu = calculate_shift(image_seq_len=seq_len, base_seq_len=base_seq_len, max_seq_len=max_seq_len, base_shift=base_shift, max_shift=max_shift)
else:
self.mu = mu
sigmas = torch.arange(1, self.pseudo_timestep_range + 1, 1) / self.pseudo_timestep_range
sigmas = FlowMatchEulerDiscreteScheduler.time_shift(None, self.mu, 1.0, sigmas)
self.register_buffer('sigmas', sigmas)
@property
def sigma_min(self):
@@ -270,7 +280,7 @@ class PredictionFlux(AbstractPrediction):
return sigma
def sigma(self, timestep):
return flux_time_shift(self.shift, 1.0, timestep)
return timestep
def percent_to_sigma(self, percent):
if percent <= 0.0: