diff --git a/backend/loader.py b/backend/loader.py index 4f1f73c2..7efd2aa7 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -244,6 +244,7 @@ def split_state_dict(sd, additional_state_dicts: list = None): guess.clip_target = guess.clip_target(sd) guess.model_type = guess.model_type(sd) + guess.ztsnr = 'ztsnr' in sd state_dict = { guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix), diff --git a/modules/processing.py b/modules/processing.py index 9bc3333d..5f821af5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -974,7 +974,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: sd_models.apply_alpha_schedule_override(p.sd_model, p) sigmas_backup = None - if opts.sd_noise_schedule == "Zero Terminal SNR" and p is not None: + if (opts.sd_noise_schedule == "Zero Terminal SNR" or (hasattr(p.sd_model.model_config, 'ztsnr') and p.sd_model.model_config.ztsnr)) and p is not None: p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule sigmas_backup = p.sd_model.forge_objects.unet.model.predictor.sigmas p.sd_model.forge_objects.unet.model.predictor.set_sigmas(rescale_zero_terminal_snr_sigmas(p.sd_model.forge_objects.unet.model.predictor.sigmas))