Hires additional modules (#2116)

adds selection of none/same/different modules for hiresfix
('Use same choices' default option has priority over other selections made at same time.)
includes saving/loading from infotext
This commit is contained in:
DenOfEquity
2024-10-21 11:03:12 +01:00
committed by GitHub
parent edc46380cc
commit aaa2fe761b
4 changed files with 71 additions and 14 deletions

View File

@@ -824,7 +824,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
set_config(p.override_settings, is_api=True, run_callbacks=False, save_config=False)
# load/reload model and manage prompt cache as needed
manage_model_and_prompt_cache(p)
if p.highresfix_quick == True:
# avoid model load here, as it could be redundant
pass
else:
manage_model_and_prompt_cache(p)
if p.scripts is not None:
p.scripts.before_process(p)
@@ -1177,6 +1181,7 @@ def old_hires_fix_first_pass_dimensions(width, height):
@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
enable_hr: bool = False
highresfix_quick: bool = False
denoising_strength: float = 0.75
firstphase_width: int = 0
firstphase_height: int = 0
@@ -1186,6 +1191,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_resize_x: int = 0
hr_resize_y: int = 0
hr_checkpoint_name: str = None
hr_additional_modules: list = field(default_factory=list)
hr_sampler_name: str = None
hr_scheduler: str = None
hr_prompt: str = ''
@@ -1275,6 +1281,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
if isinstance(self.hr_additional_modules, list) and len(self.hr_additional_modules) > 0:
if 'Use same choices' not in self.hr_additional_modules:
for i, m in enumerate(self.hr_additional_modules):
self.extra_generation_params[f'Hires Module {i+1}'] = os.path.splitext(os.path.basename(m))[0]
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
@@ -1381,14 +1392,27 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = None
with sd_models.SkipWritingToConfig():
fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
fp_additional_modules = getattr(shared.opts, 'forge_additional_modules')
reload = False
if 'Use same choices' not in self.hr_additional_modules:
if sorted(self.hr_additional_modules) != sorted(fp_additional_modules):
main_entry.modules_change(self.hr_additional_modules, save=False, refresh=False)
reload = True
if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
firstpass_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
if firstpass_checkpoint != self.hr_checkpoint_name:
try:
main_entry.checkpoint_change(self.hr_checkpoint_name, save=False)
sd_models.forge_model_reload();
finally:
main_entry.checkpoint_change(firstpass_checkpoint, save=False)
if self.hr_checkpoint_name != fp_checkpoint:
main_entry.checkpoint_change(self.hr_checkpoint_name, save=False, refresh=False)
reload = True
if reload:
try:
main_entry.refresh_model_loading_parameters()
sd_models.forge_model_reload()
finally:
main_entry.modules_change(fp_additional_modules, save=False, refresh=False)
main_entry.checkpoint_change(fp_checkpoint, save=False)
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
@@ -1618,6 +1642,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
force_task_id: str = None
hr_distilled_cfg: float = 3.5 # needed here for cached_params
highresfix_quick: bool = False
image_mask: Any = field(default=None, init=False)