mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 19:21:21 +00:00
Merge pull request #2119 from lllyasviel/fix/ztsnr
This commit is contained in:
@@ -45,6 +45,26 @@ def flux_time_shift(mu, sigma, t):
|
|||||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_zero_terminal_snr_sigmas(sigmas):
    """Rescale a sigma schedule so the terminal step has (near-)zero SNR.

    Implements the rescaling from "Common Diffusion Noise Schedules and
    Sample Steps are Flawed" (Lin et al., 2023): convert sigmas to
    sqrt(alpha_bar), shift and scale so alpha_bar at the last timestep goes
    to zero while alpha_bar at the first timestep is preserved, then convert
    back to sigmas.

    Args:
        sigmas: 1-D tensor of noise sigmas ordered first-to-last timestep
            (ascending noise). Assumed to be a torch tensor — uses
            ``.sqrt()`` / ``.clone()``; TODO(review) confirm dtype/ordering
            against callers.

    Returns:
        Tensor of the same shape holding the zero-terminal-SNR schedule.
        The first sigma is unchanged; the last is a large finite value.
    """
    # sigma = sqrt((1 - abar) / abar)  <=>  abar = 1 / (sigma^2 + 1)
    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt back to sigmas.
    alphas_bar = alphas_bar_sqrt ** 2  # revert sqrt
    # True zero terminal SNR means alphas_bar[-1] == 0, which maps to an
    # infinite sigma; clamp to a tiny positive value so the schedule stays
    # finite and usable by the samplers.
    alphas_bar[-1] = 4.8973451890853435e-08
    return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||||
|
|
||||||
|
|
||||||
class AbstractPrediction(torch.nn.Module):
|
class AbstractPrediction(torch.nn.Module):
|
||||||
def __init__(self, sigma_data=1.0, prediction_type='epsilon'):
|
def __init__(self, sigma_data=1.0, prediction_type='epsilon'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -114,6 +134,10 @@ class Prediction(AbstractPrediction):
|
|||||||
self.register_buffer('log_sigmas', sigmas.log().float())
|
self.register_buffer('log_sigmas', sigmas.log().float())
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def set_sigmas(self, sigmas):
    """Install a new sigma schedule on this module.

    Registers both the raw sigmas and their elementwise log as buffers,
    overwriting any previously registered schedule. Both are stored as
    float32.

    Args:
        sigmas: 1-D tensor of strictly positive sigmas (positivity is
            required for ``log()`` to be finite — TODO(review) confirm
            callers guarantee this).
    """
    self.register_buffer('sigmas', sigmas.float())
    self.register_buffer('log_sigmas', sigmas.log().float())
|
||||||
|
|
||||||
@property
def sigma_min(self):
    """Smallest sigma of the schedule.

    Returns the first entry of ``self.sigmas`` — assumes the buffer is
    stored in ascending order (TODO(review): confirm ordering convention).
    """
    return self.sigmas[0]
|
|||||||
@@ -35,6 +35,7 @@ from modules.sd_models import apply_token_merging, forge_model_reload
|
|||||||
from modules_forge.utils import apply_circular_forge
|
from modules_forge.utils import apply_circular_forge
|
||||||
from modules_forge import main_entry
|
from modules_forge import main_entry
|
||||||
from backend import memory_management
|
from backend import memory_management
|
||||||
|
from backend.modules.k_prediction import rescale_zero_terminal_snr_sigmas
|
||||||
|
|
||||||
|
|
||||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||||
@@ -969,25 +970,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.n_iter > 1:
|
if p.n_iter > 1:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
|
||||||
|
# TODO: This currently seems broken. It should be fixed or removed.
|
||||||
sd_models.apply_alpha_schedule_override(p.sd_model, p)
|
sd_models.apply_alpha_schedule_override(p.sd_model, p)
|
||||||
|
|
||||||
alphas_cumprod_modifiers = p.sd_model.forge_objects.unet.model_options.get('alphas_cumprod_modifiers', [])
|
sigmas_backup = None
|
||||||
alphas_cumprod_backup = None
|
if opts.sd_noise_schedule == "Zero Terminal SNR" and p is not None:
|
||||||
|
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||||
if len(alphas_cumprod_modifiers) > 0:
|
sigmas_backup = p.sd_model.forge_objects.unet.model.predictor.sigmas
|
||||||
alphas_cumprod_backup = p.sd_model.alphas_cumprod
|
p.sd_model.forge_objects.unet.model.predictor.set_sigmas(rescale_zero_terminal_snr_sigmas(p.sd_model.forge_objects.unet.model.predictor.sigmas))
|
||||||
for modifier in alphas_cumprod_modifiers:
|
|
||||||
p.sd_model.alphas_cumprod = modifier(p.sd_model.alphas_cumprod)
|
|
||||||
p.sd_model.forge_objects.unet.model.model_sampling.set_sigmas(((1 - p.sd_model.alphas_cumprod) / p.sd_model.alphas_cumprod) ** 0.5)
|
|
||||||
|
|
||||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||||
|
|
||||||
for x_sample in samples_ddim:
|
for x_sample in samples_ddim:
|
||||||
p.latents_after_sampling.append(x_sample)
|
p.latents_after_sampling.append(x_sample)
|
||||||
|
|
||||||
if alphas_cumprod_backup is not None:
|
if sigmas_backup is not None:
|
||||||
p.sd_model.alphas_cumprod = alphas_cumprod_backup
|
p.sd_model.forge_objects.unet.model.predictor.set_sigmas(sigmas_backup)
|
||||||
p.sd_model.forge_objects.unet.model.model_sampling.set_sigmas(((1 - p.sd_model.alphas_cumprod) / p.sd_model.alphas_cumprod) ** 0.5)
|
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
ps = scripts.PostSampleArgs(samples_ddim)
|
ps = scripts.PostSampleArgs(samples_ddim)
|
||||||
|
|||||||
Reference in New Issue
Block a user