Hires additional modules (#2116)

adds selection of none/same/different modules for hiresfix
('Use same choices' default option has priority over other selections made at same time.)
includes saving/loading from infotext
This commit is contained in:
DenOfEquity
2024-10-21 11:03:12 +01:00
committed by GitHub
parent edc46380cc
commit aaa2fe761b
4 changed files with 71 additions and 14 deletions

View File

@@ -424,6 +424,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
# VAE / TE
modules = []
hr_modules = []
vae = res.pop('VAE', None) # old form
if vae:
modules = [vae]
@@ -439,6 +440,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
break
if not added:
modules.append(res[key]) # so it shows in the override section (consistent with checkpoint and old vae)
elif key.startswith('Hires Module '):
for knownmodule in main_entry.module_list.keys():
filename, _ = os.path.splitext(knownmodule)
if res[key] == filename:
hr_modules.append(knownmodule)
break
if modules != []:
current_modules = shared.opts.forge_additional_modules
@@ -449,6 +456,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if sorted(modules) != sorted(basename_modules):
res['VAE/TE'] = modules
res['Hires VAE/TE'] = hr_modules
return res
@@ -465,7 +474,6 @@ infotext_to_setting_name_mapping = [
('Schedule type', 'k_sched_type'),
]
"""
from ast import literal_eval
def create_override_settings_dict(text_pairs):
"""creates processing's override_settings parameters from gradio's multiselect

View File

@@ -824,7 +824,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
set_config(p.override_settings, is_api=True, run_callbacks=False, save_config=False)
# load/reload model and manage prompt cache as needed
manage_model_and_prompt_cache(p)
if p.highresfix_quick == True:
# avoid model load here, as it could be redundant
pass
else:
manage_model_and_prompt_cache(p)
if p.scripts is not None:
p.scripts.before_process(p)
@@ -1177,6 +1181,7 @@ def old_hires_fix_first_pass_dimensions(width, height):
@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
enable_hr: bool = False
highresfix_quick: bool = False
denoising_strength: float = 0.75
firstphase_width: int = 0
firstphase_height: int = 0
@@ -1186,6 +1191,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_resize_x: int = 0
hr_resize_y: int = 0
hr_checkpoint_name: str = None
hr_additional_modules: list = field(default_factory=list)
hr_sampler_name: str = None
hr_scheduler: str = None
hr_prompt: str = ''
@@ -1275,6 +1281,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
if isinstance(self.hr_additional_modules, list) and len(self.hr_additional_modules) > 0:
if 'Use same choices' not in self.hr_additional_modules:
for i, m in enumerate(self.hr_additional_modules):
self.extra_generation_params[f'Hires Module {i+1}'] = os.path.splitext(os.path.basename(m))[0]
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
@@ -1381,14 +1392,27 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = None
with sd_models.SkipWritingToConfig():
fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
fp_additional_modules = getattr(shared.opts, 'forge_additional_modules')
reload = False
if 'Use same choices' not in self.hr_additional_modules:
if sorted(self.hr_additional_modules) != sorted(fp_additional_modules):
main_entry.modules_change(self.hr_additional_modules, save=False, refresh=False)
reload = True
if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
firstpass_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
if firstpass_checkpoint != self.hr_checkpoint_name:
try:
main_entry.checkpoint_change(self.hr_checkpoint_name, save=False)
sd_models.forge_model_reload();
finally:
main_entry.checkpoint_change(firstpass_checkpoint, save=False)
if self.hr_checkpoint_name != fp_checkpoint:
main_entry.checkpoint_change(self.hr_checkpoint_name, save=False, refresh=False)
reload = True
if reload:
try:
main_entry.refresh_model_loading_parameters()
sd_models.forge_model_reload()
finally:
main_entry.modules_change(fp_additional_modules, save=False, refresh=False)
main_entry.checkpoint_change(fp_checkpoint, save=False)
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
@@ -1618,6 +1642,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
force_task_id: str = None
hr_distilled_cfg: float = 3.5 # needed here for cached_params
highresfix_quick: bool = False
image_mask: Any = field(default=None, init=False)

View File

@@ -12,7 +12,7 @@ import gradio as gr
from modules_forge import main_thread
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, distilled_cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, hr_cfg: float, hr_distilled_cfg: float, override_settings_texts, *args, force_enable_hr=False):
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, distilled_cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_additional_modules: list, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, hr_cfg: float, hr_distilled_cfg: float, override_settings_texts, *args, force_enable_hr=False):
override_settings = create_override_settings_dict(override_settings_texts)
if force_enable_hr:
@@ -38,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
hr_additional_modules=hr_additional_modules,
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
hr_prompt=hr_prompt,
@@ -92,7 +93,7 @@ def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery
p.extra_generation_params['Original Size'] = f'{args[8]}x{args[7]}'
p.override_settings['save_images_before_highres_fix'] = False
p.override_settings['sd_model_checkpoint'] = p.hr_checkpoint_name
p.highresfix_quick = True
with closing(p):
processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)

View File

@@ -347,10 +347,31 @@ def create_ui():
hr_distilled_cfg = gr.Slider(minimum=0.0, maximum=30.0, step=0.1, label="Hires Distilled CFG Scale", value=3.5, elem_id="txt2img_hr_distilled_cfg")
hr_cfg = gr.Slider(minimum=1.0, maximum=30.0, step=0.1, label="Hires CFG Scale", value=7.0, elem_id="txt2img_hr_cfg")
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=shared.opts.hires_fix_show_sampler) as hr_sampler_container:
hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=shared.opts.hires_fix_show_sampler) as hr_checkpoint_container:
hr_checkpoint_name = gr.Dropdown(label='Hires Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint", scale=2)
hr_checkpoint_refresh = ToolButton(value=refresh_symbol)
def get_additional_modules():
modules_list = ['Use same choices']
if main_entry.module_list == {}:
_, modules = main_entry.refresh_models()
modules_list += list(modules)
else:
modules_list += list(main_entry.module_list.keys())
return modules_list
modules_list = get_additional_modules()
def refresh_model_and_modules():
modules_list = get_additional_modules()
return gr.update(choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)), gr.update(choices=modules_list)
hr_additional_modules = gr.Dropdown(label='Hires VAE / Text Encoder', elem_id="hr_vae_te", choices=modules_list, value=["Use same choices"], multiselect=True, scale=3)
hr_checkpoint_refresh.click(fn=refresh_model_and_modules, outputs=[hr_checkpoint_name, hr_additional_modules], show_progress=False)
with FormRow(elem_id="txt2img_hires_fix_row3b", variant="compact", visible=shared.opts.hires_fix_show_sampler) as hr_sampler_container:
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
hr_scheduler = gr.Dropdown(label='Hires schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")
@@ -421,6 +442,7 @@ def create_ui():
hr_resize_x,
hr_resize_y,
hr_checkpoint_name,
hr_additional_modules,
hr_sampler_name,
hr_scheduler,
hr_prompt,
@@ -489,6 +511,7 @@ def create_ui():
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
PasteField(hr_additional_modules, "Hires VAE/TE", api="hr_additional_modules"),
PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),