diff --git a/modules/processing.py b/modules/processing.py index 44d47e8ca..efa6eafa8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -377,6 +377,9 @@ class StableDiffusionProcessing: self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + def get_conds(self): + return self.c, self.uc + def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) @@ -611,6 +614,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} try: + # after running refiner, the refiner model is not unloaded - webui swaps back to main model here + if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint: + sd_models.reload_model_weights() + # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: p.override_settings.pop('sd_model_checkpoint', None) @@ -710,6 +717,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.interrupted: break + sd_models.reload_model_weights() # model can be changed for example by refiner + p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] @@ -1201,6 +1210,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast(): extra_networks.activate(self, self.extra_network_data) + def get_conds(self): + if self.is_hr_pass: + return self.hr_c, self.hr_uc + + return super().get_conds() + + def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() diff --git a/modules/sd_models.py b/modules/sd_models.py index 7a866a07d..a178adcac 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -295,11 +295,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): return res +class SkipWritingToConfig: + """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight.""" + + skip = False + previous = None + + def __enter__(self): + self.previous = SkipWritingToConfig.skip + SkipWritingToConfig.skip = True + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + SkipWritingToConfig.skip = self.previous + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) @@ -624,8 +640,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): timer.record("send model to device") model_data.set_sd_model(already_loaded) - shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title - shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 + + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title + shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") return model_data.sd_model elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index d826222cd..a532e0137 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -38,16 +38,24 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self, model, sampler): + def __init__(self, sampler): super().__init__() - self.inner_model = model + self.model_wrap = None self.mask = None self.nmask = None self.init_latent = None + self.steps = None self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False self.sampler = sampler + self.model_wrap = None + self.p = None + + @property + def inner_model(self): + raise NotImplementedError() + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -68,10 +76,21 @@ class CFGDenoiser(torch.nn.Module): def get_pred_x0(self, x_in, x_out, sigma): return x_out + def update_inner_model(self): + self.model_wrap = None + + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + if sd_samplers_common.apply_refiner(self): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 97bc08041..35c4d657f 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -3,7 +3,7 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared +from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models from modules.shared import opts, state import k_diffusion.sampling @@ -131,6 +131,35 @@ def replace_torchsde_browinan(): replace_torchsde_browinan() +def apply_refiner(sampler): + completed_ratio = sampler.step / sampler.steps + + if completed_ratio <= shared.opts.sd_refiner_switch_at: + return False + + if shared.opts.sd_refiner_checkpoint == "None": + return False + + if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint: + return False + + refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) + if refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') + + sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title + sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at + + with sd_models.SkipWritingToConfig(): + sd_models.reload_model_weights(info=refiner_checkpoint_info) + + devices.torch_gc() + sampler.p.setup_conds() + sampler.update_inner_model() + + return True + + class TorchHijack: """This is here to replace torch.randn_like of k-diffusion. @@ -176,8 +205,9 @@ class Sampler: self.conditioning_key = shared.sd_model.model.conditioning_key - self.model_wrap = None + self.p = None self.model_wrap_cfg = None + self.sampler_extra_args = None def callback_state(self, d): step = d['i'] @@ -189,6 +219,7 @@ class Sampler: shared.total_tqdm.update() def launch_sampling(self, steps, func): + self.model_wrap_cfg.steps = steps state.sampling_steps = steps state.sampling_step = 0 @@ -208,6 +239,8 @@ class Sampler: return p.steps def initialize(self, p) -> dict: + self.p = p + self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 5613b8c18..95a43ceff 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,8 +1,7 @@ import torch import inspect import k_diffusion.sampling -from modules import sd_samplers_common, sd_samplers_extra -from modules.sd_samplers_cfg_denoiser import CFGDenoiser +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.shared import opts import modules.shared as shared @@ -53,17 +52,24 @@ k_diffusion_scheduler = { } +class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + + return self.model_wrap + + class KDiffusionSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): - super().__init__(funcname) - self.extra_params = sampler_extra_params.get(funcname, []) self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiserKDiffusion(self) + self.model_wrap = self.model_wrap_cfg.inner_model def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -144,7 +150,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): self.model_wrap_cfg.init_latent = x self.last_latent = x - extra_args = { + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -152,7 +158,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): 's_min_uncond': self.s_min_uncond } - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True @@ -184,13 +190,14 @@ class KDiffusionSampler(sd_samplers_common.Sampler): extra_params_kwargs['noise_sampler'] = noise_sampler self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index f61799a82..16572c7e0 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -45,10 +45,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module): class CFGDenoiserTimesteps(CFGDenoiser): - def __init__(self, model, sampler): - super().__init__(model, sampler) + def __init__(self, sampler): + super().__init__(sampler) - self.alphas = model.inner_model.alphas_cumprod + self.alphas = shared.sd_model.alphas_cumprod def get_pred_x0(self, x_in, x_out, sigma): ts = int(sigma.item()) @@ -61,6 +61,14 @@ class CFGDenoiserTimesteps(CFGDenoiser): return pred_x0 + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser + self.model_wrap = denoiser(shared.sd_model) + + return self.model_wrap + class CompVisSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): @@ -69,9 +77,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.eta_option_field = 'eta_ddim' self.eta_infotext_field = 'Eta DDIM' - denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser - self.model_wrap = denoiser(sd_model) - self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiserTimesteps(self) def get_timesteps(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -107,7 +113,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.model_wrap_cfg.init_latent = x self.last_latent = x - extra_args = { + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -115,7 +121,7 @@ class CompVisSampler(sd_samplers_common.Sampler): 's_min_uncond': self.s_min_uncond } - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True @@ -133,13 +139,14 @@ class CompVisSampler(sd_samplers_common.Sampler): extra_params_kwargs['timesteps'] = timesteps self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True diff --git a/modules/shared_options.py b/modules/shared_options.py index 9ae51f186..1e5b64eaf 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -140,6 +140,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"), + "sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"), + "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"), })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {