diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index e71c96015..6711ca16e 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [
     ('Pad conds', 'pad_cond_uncond'),
     ('VAE Encoder', 'sd_vae_encode_method'),
     ('VAE Decoder', 'sd_vae_decode_method'),
+    ('Refiner', 'sd_refiner_checkpoint'),
+    ('Refiner switch at', 'sd_refiner_switch_at'),
 ]
diff --git a/modules/processing.py b/modules/processing.py
index 317450065..cb5e3f725 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -178,6 +178,8 @@ class StableDiffusionProcessing:
         self.extra_network_data = None
         self.seeds = None
         self.subseeds = None
+        self.recorded_checkpoint = None
+        self.recorded_checkpoint_hash = None
         self.step_multiplier = 1
         self.cached_uc = StableDiffusionProcessing.cached_uc
@@ -186,6 +188,7 @@ class StableDiffusionProcessing:
         self.c = None
 
         self.user = None
+        self.image_conditioning = None
 
     @property
     def sd_model(self):
@@ -377,6 +380,54 @@ class StableDiffusionProcessing:
         """Returns whether generated images need to be written to disk"""
         return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
 
+    def run_refiner(self, samples):
+        shared.state.nextjob()
+
+        stopped_at = self.sampler.stop_at
+        self.sampler = None
+
+        a_is_sdxl = shared.sd_model.is_sdxl
+
+        decoded_samples = decode_latent_batch(shared.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
+        refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
+        if refiner_checkpoint_info is None:
+            raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
+
+        self.recorded_checkpoint = shared.sd_model.sd_checkpoint_info.name_for_extra
+        self.recorded_checkpoint_hash = shared.sd_model.sd_model_hash
+        self.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+        self.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+
+        with sd_models.SkipWritingToConfig():
+            sd_models.reload_model_weights(info=refiner_checkpoint_info)
+
+        devices.torch_gc()
+        self.setup_conds()
+
+        b_is_sdxl = shared.sd_model.is_sdxl
+
+        if a_is_sdxl != b_is_sdxl:
+            decoded_samples = torch.stack(decoded_samples).float()
+            decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
+            latent = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
+        else:
+            latent = samples
+
+        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+
+        with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
+            denoising_strength = self.denoising_strength
+
+            self.denoising_strength = 1.0 - stopped_at / self.steps
+            self.image_conditioning = txt2img_image_conditioning(shared.sd_model, latent, self.width, self.height)
+            self.sampler = sd_samplers.create_sampler(self.sampler_name, shared.sd_model)
+            samples = self.sampler.sample_img2img(self, latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
+
+            self.denoising_strength = denoising_strength
+
+        return samples
+
 
 class Processed:
     def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
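The handoff arithmetic in `run_refiner` is easy to misread: `stop_at` is the step at which the base model stopped, and the refiner resumes as an img2img pass over whatever noise fraction remains. A minimal sketch of that arithmetic, using hypothetical values (30 steps, switch at 0.8) rather than anything taken from this patch:

```python
# Hypothetical numbers; mirrors the stop_at assignment made later in
# process_images_inner and the math inside run_refiner.
steps = 30
switch_at = 0.8                                    # shared.opts.sd_refiner_switch_at

stopped_at = max(1, int(switch_at * steps - 1))    # base model stops at step 23

denoising_strength = 1.0 - stopped_at / steps      # ~0.233 of the noise left to remove
refiner_steps = max(1, steps - stopped_at - 1)     # refiner runs the remaining 6 steps

print(stopped_at, round(denoising_strength, 3), refiner_steps)  # 23 0.233 6
```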
@@ -553,6 +604,9 @@ class DecodedSamples(list):
 
 
 def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+    if getattr(batch, 'already_decoded', False):
+        return batch
+
     samples = DecodedSamples()
 
     for i in range(batch.shape[0]):
@@ -632,8 +686,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
         "Size": f"{p.width}x{p.height}",
-        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
-        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
+        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else p.recorded_checkpoint_hash or shared.sd_model.sd_model_hash),
+        "Model": (None if not opts.add_model_name_to_info else p.recorded_checkpoint or shared.sd_model.sd_checkpoint_info.name_for_extra),
         "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -666,6 +720,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
 
     try:
+        # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
+        if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
+            sd_models.reload_model_weights()
+
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
         if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
             p.override_settings.pop('sd_model_checkpoint', None)
@@ -737,6 +795,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     infotexts = []
     output_images = []
 
+    have_refiner = shared.opts.sd_refiner_switch_at < 1.0 and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint
+
     with torch.no_grad(), p.sd_model.ema_scope():
         with devices.autocast():
             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -750,6 +810,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         if state.job_count == -1:
             state.job_count = p.n_iter
 
+        if have_refiner:
+            state.job_count *= 2
+            shared.total_tqdm.updateTotal(p.steps * state.job_count // 2)
+
         for n in range(p.n_iter):
             p.iteration = n
@@ -798,16 +862,19 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
+            if have_refiner:
+                p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))
+
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
-            if getattr(samples_ddim, 'already_decoded', False):
-                x_samples_ddim = samples_ddim
-            else:
-                if opts.sd_vae_decode_method != 'Full':
-                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+            if opts.sd_vae_decode_method != 'Full':
+                p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
 
-                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+            if have_refiner:
+                samples_ddim = p.run_refiner(samples_ddim)
+
+            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
 
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
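One non-obvious bit above: doubling `state.job_count` would normally double the console progress bar's total, so `updateTotal(p.steps * state.job_count // 2)` halves it back, keeping the overall step budget at `steps * n_iter` whether or not a refiner job is added (base and refiner together still execute roughly `steps` steps per batch). A sketch with hypothetical numbers:

```python
# Hypothetical values showing the progress-total bookkeeping stays constant.
n_iter, steps = 4, 30

job_count = n_iter                       # state.job_count without a refiner
have_refiner = True
if have_refiner:
    job_count *= 2                       # every batch gains a second (refiner) job
    tqdm_total = steps * job_count // 2  # 30 * 8 // 2 = 120 == steps * n_iter

print(job_count, tqdm_total)             # 8 120
```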
@@ -989,6 +1056,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_uc = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
+        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+
         if self.enable_hr:
             if self.hr_checkpoint_name:
                 self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
@@ -1065,8 +1134,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
 
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
-
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
         del x
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f60516046..981aa93d7 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
     return res
 
 
+class SkipWritingToConfig:
+    """This context manager prevents load_model_weights from writing the checkpoint name to the config when it loads weights."""
+
+    skip = False
+    previous = None
+
+    def __enter__(self):
+        self.previous = SkipWritingToConfig.skip
+        SkipWritingToConfig.skip = True
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        SkipWritingToConfig.skip = self.previous
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
-    shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+    if not SkipWritingToConfig.skip:
+        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index 4a8396f97..e3b15ee95 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -44,7 +44,7 @@ class VanillaStableDiffusionSampler:
         return 0
 
     def launch_sampling(self, steps, func):
-        state.sampling_steps = steps
+        state.sampling_steps = self.stop_at if self.stop_at is not None else steps
         state.sampling_step = 0
 
         try:
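`SkipWritingToConfig` keeps its state on the class rather than the instance, and saves the previous value on entry, so nested uses compose correctly. A self-contained sketch of the same save-and-restore pattern (the `load_weights` stand-in is hypothetical, not part of the patch):

```python
class SkipWritingToConfig:
    """Same shape as the patch: a class-level flag saved on enter, restored on exit."""
    skip = False

    def __enter__(self):
        self.previous = SkipWritingToConfig.skip
        SkipWritingToConfig.skip = True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        SkipWritingToConfig.skip = self.previous


def load_weights():
    # stand-in for load_model_weights: reads the class attribute, not an instance
    print("writes sd_model_checkpoint:", not SkipWritingToConfig.skip)


load_weights()                      # writes sd_model_checkpoint: True
with SkipWritingToConfig():
    load_weights()                  # writes sd_model_checkpoint: False
    with SkipWritingToConfig():     # nested use...
        load_weights()              # ...still False
    load_weights()                  # still False: the outer context remains active
load_weights()                      # True again after the outer context exits
```

Restoring `previous` instead of unconditionally resetting `skip = False` on exit is what makes the nesting safe.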
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index db71a549a..359b2d52a 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -305,7 +305,7 @@ class KDiffusionSampler:
             shared.total_tqdm.update()
 
     def launch_sampling(self, steps, func):
-        state.sampling_steps = steps
+        state.sampling_steps = self.stop_at if self.stop_at is not None else steps
         state.sampling_step = 0
 
         try:
diff --git a/modules/shared.py b/modules/shared.py
index 078e81352..ed8395dc4 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -461,6 +461,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
+    "sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"),
+    "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the switch to the refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
 }))
 
 options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
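Note how these option defaults interact with the `have_refiner` gate in `process_images_inner`: with `sd_refiner_switch_at` left at 1.0, the `< 1.0` check short-circuits to false and the refiner path is never taken, which is why the slider's info text describes 1 as "never". A sketch with a hypothetical checkpoint title:

```python
# Hypothetical values; mirrors the have_refiner gate from process_images_inner.
sd_refiner_switch_at = 1.0                 # the slider's default above
sd_refiner_checkpoint = None               # the dropdown's default above
loaded_title = "sd_xl_base_1.0.safetensors"

have_refiner = sd_refiner_switch_at < 1.0 and loaded_title != sd_refiner_checkpoint
print(have_refiner)                        # False: defaults leave generation unchanged
```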