diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 1ddebaae..6e3f7b89 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -424,6 +424,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model # VAE / TE modules = [] + hr_modules = [] vae = res.pop('VAE', None) # old form if vae: modules = [vae] @@ -439,6 +440,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model break if not added: modules.append(res[key]) # so it shows in the override section (consistent with checkpoint and old vae) + elif key.startswith('Hires Module '): + for knownmodule in main_entry.module_list.keys(): + filename, _ = os.path.splitext(knownmodule) + if res[key] == filename: + hr_modules.append(knownmodule) + break if modules != []: current_modules = shared.opts.forge_additional_modules @@ -449,6 +456,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if sorted(modules) != sorted(basename_modules): res['VAE/TE'] = modules + res['Hires VAE/TE'] = hr_modules + return res @@ -465,7 +474,6 @@ infotext_to_setting_name_mapping = [ ('Schedule type', 'k_sched_type'), ] """ - from ast import literal_eval def create_override_settings_dict(text_pairs): """creates processing's override_settings parameters from gradio's multiselect diff --git a/modules/processing.py b/modules/processing.py index 5f821af5..c0977b64 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -824,7 +824,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: set_config(p.override_settings, is_api=True, run_callbacks=False, save_config=False) # load/reload model and manage prompt cache as needed - manage_model_and_prompt_cache(p) + if p.highresfix_quick == True: + # avoid model load here, as it could be redundant + pass + else: + manage_model_and_prompt_cache(p) if p.scripts is not None: p.scripts.before_process(p) @@ -1177,6 +1181,7 @@ def old_hires_fix_first_pass_dimensions(width, height): @dataclass(repr=False) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): enable_hr: bool = False + highresfix_quick: bool = False denoising_strength: float = 0.75 firstphase_width: int = 0 firstphase_height: int = 0 @@ -1186,6 +1191,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_resize_x: int = 0 hr_resize_y: int = 0 hr_checkpoint_name: str = None + hr_additional_modules: list = field(default_factory=list) hr_sampler_name: str = None hr_scheduler: str = None hr_prompt: str = '' @@ -1275,6 +1281,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title + if isinstance(self.hr_additional_modules, list) and len(self.hr_additional_modules) > 0: + if 'Use same choices' not in self.hr_additional_modules: + for i, m in enumerate(self.hr_additional_modules): + self.extra_generation_params[f'Hires Module {i+1}'] = os.path.splitext(os.path.basename(m))[0] + if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: self.extra_generation_params["Hires sampler"] = self.hr_sampler_name @@ -1381,14 +1392,27 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): decoded_samples = None with sd_models.SkipWritingToConfig(): + fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint') + fp_additional_modules = getattr(shared.opts, 'forge_additional_modules') + + reload = False + if 'Use same choices' not in self.hr_additional_modules: + if sorted(self.hr_additional_modules) != sorted(fp_additional_modules): + main_entry.modules_change(self.hr_additional_modules, save=False, refresh=False) + reload = True + if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint': - firstpass_checkpoint = getattr(shared.opts, 'sd_model_checkpoint') - if firstpass_checkpoint != self.hr_checkpoint_name: - try: - main_entry.checkpoint_change(self.hr_checkpoint_name, save=False) - sd_models.forge_model_reload(); - finally: - main_entry.checkpoint_change(firstpass_checkpoint, save=False) + if self.hr_checkpoint_name != fp_checkpoint: + main_entry.checkpoint_change(self.hr_checkpoint_name, save=False, refresh=False) + reload = True + + if reload: + try: + main_entry.refresh_model_loading_parameters() + sd_models.forge_model_reload() + finally: + main_entry.modules_change(fp_additional_modules, save=False, refresh=False) + main_entry.checkpoint_change(fp_checkpoint, save=False) return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) @@ -1618,6 +1642,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): force_task_id: str = None hr_distilled_cfg: float = 3.5 # needed here for cached_params + highresfix_quick: bool = False image_mask: Any = field(default=None, init=False) diff --git a/modules/txt2img.py b/modules/txt2img.py index 3c6d373f..efc0d402 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -12,7 +12,7 @@ import gradio as gr from modules_forge import main_thread -def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, distilled_cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, hr_cfg: float, hr_distilled_cfg: float, override_settings_texts, *args, force_enable_hr=False): +def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, distilled_cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_additional_modules: list, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, hr_cfg: float, hr_distilled_cfg: float, override_settings_texts, *args, force_enable_hr=False): override_settings = create_override_settings_dict(override_settings_texts) if force_enable_hr: @@ -38,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, + hr_additional_modules=hr_additional_modules, hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name, hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler, hr_prompt=hr_prompt, @@ -92,7 +93,7 @@ def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery p.extra_generation_params['Original Size'] = f'{args[8]}x{args[7]}' p.override_settings['save_images_before_highres_fix'] = False - p.override_settings['sd_model_checkpoint'] = p.hr_checkpoint_name + p.highresfix_quick = True with closing(p): processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) diff --git a/modules/ui.py b/modules/ui.py index 247a0af1..b65009a0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -347,10 +347,31 @@ def create_ui(): hr_distilled_cfg = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label="Hires Distilled CFG Scale", value=3.5, elem_id="txt2img_hr_distilled_cfg") hr_cfg = gr.Slider(minimum=1.0, maximum=30.0, step=0.1, label="Hires CFG Scale", value=7.0, elem_id="txt2img_hr_cfg") - with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=shared.opts.hires_fix_show_sampler) as hr_sampler_container: - hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") - create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") + with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=shared.opts.hires_fix_show_sampler) as hr_checkpoint_container: + hr_checkpoint_name = gr.Dropdown(label='Hires Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint", scale=2) + hr_checkpoint_refresh = ToolButton(value=refresh_symbol) + + def get_additional_modules(): + modules_list = ['Use same choices'] + if main_entry.module_list == {}: + _, modules = main_entry.refresh_models() + modules_list += list(modules) + else: + modules_list += list(main_entry.module_list.keys()) + return modules_list + + modules_list = get_additional_modules() + + def refresh_model_and_modules(): + modules_list = get_additional_modules() + return gr.update(choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)), gr.update(choices=modules_list) + + hr_additional_modules = gr.Dropdown(label='Hires VAE / Text Encoder', elem_id="hr_vae_te", choices=modules_list, value=["Use same choices"], multiselect=True, scale=3) + + hr_checkpoint_refresh.click(fn=refresh_model_and_modules, outputs=[hr_checkpoint_name, hr_additional_modules], show_progress=False) + + with FormRow(elem_id="txt2img_hires_fix_row3b", variant="compact", visible=shared.opts.hires_fix_show_sampler) as hr_sampler_container: hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") hr_scheduler = gr.Dropdown(label='Hires schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler") @@ -421,6 +442,7 @@ def create_ui(): hr_resize_x, hr_resize_y, hr_checkpoint_name, + hr_additional_modules, hr_sampler_name, hr_scheduler, hr_prompt, @@ -489,6 +511,7 @@ def create_ui(): PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"), PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"), PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"), + PasteField(hr_additional_modules, "Hires VAE/TE", api="hr_additional_modules"), PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"), PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"), PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),