From 98e0adcc78c1641cc9352bfdad8843bfb49cfc19 Mon Sep 17 00:00:00 2001 From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com> Date: Mon, 11 Nov 2024 11:20:48 +0000 Subject: [PATCH] Refiner (#2192) restore refiner, now uses new mode loading functions. No significant changes otherwise. --- modules/processing_scripts/refiner.py | 53 ++++++++--------- modules/sd_samplers_cfg_denoiser.py | 6 +- modules/sd_samplers_common.py | 86 ++++++++++++++------------- 3 files changed, 73 insertions(+), 72 deletions(-) diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index b6e70464..6e520093 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -21,34 +21,29 @@ class ScriptRefiner(scripts.ScriptBuiltinUI): def ui(self, is_img2img): with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner: - gr.Markdown('Refiner is currently under maintenance and unavailable. Sorry for the inconvenience.') + with gr.Row(): + refiner_checkpoint = gr.Dropdown(label='Checkpoint', info='(use model of same architecture)', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles(use_short=True)], value='', tooltip="switch to another model in the middle of generation") + create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles(use_short=True)}, self.elem_id("checkpoint_refresh")) + + refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") + + def lookup_checkpoint(title): + info = sd_models.get_closet_checkpoint_match(title) + return None if info is None else info.short_title + + self.infotext_fields = [ + PasteField(enable_refiner, lambda d: 'Refiner' in d), + PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"), + PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"), + ] - # - # with gr.Row(): - # refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles()], value='', tooltip="switch to another model in the middle of generation", interactive=False, visible=False) - # # create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) - # - # refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation", interactive=False, visible=False) - # - # def lookup_checkpoint(title): - # info = sd_models.get_closet_checkpoint_match(title) - # return None if info is None else info.title - # - # self.infotext_fields = [ - # PasteField(enable_refiner, lambda d: 'Refiner' in d), - # PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"), - # PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"), - # ] + return enable_refiner, refiner_checkpoint, refiner_switch_at - return [enable_refiner] # , refiner_checkpoint, refiner_switch_at - - def setup(self, p, enable_refiner): - pass - # # the actual implementation is in sd_samplers_common.py, apply_refiner - # - # if not enable_refiner or refiner_checkpoint in (None, "", "None"): - # p.refiner_checkpoint = None - # p.refiner_switch_at = None - # else: - # p.refiner_checkpoint = refiner_checkpoint - # p.refiner_switch_at = refiner_switch_at + def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at): + # the actual implementation is in sd_samplers_common.py, apply_refiner + if not enable_refiner or refiner_checkpoint in (None, "", "None"): + p.refiner_checkpoint = None + p.refiner_switch_at = None + else: + p.refiner_checkpoint = refiner_checkpoint + p.refiner_switch_at = refiner_switch_at diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 95f82402..82b41ea4 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -168,9 +168,9 @@ class CFGDenoiser(torch.nn.Module): x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None]) sigma = real_sigma - # if sd_samplers_common.apply_refiner(self, x): - # cond = self.sampler.sampler_extra_args['cond'] - # uncond = self.sampler.sampler_extra_args['uncond'] + if sd_samplers_common.apply_refiner(self, x): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) if uncond is not None else None diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index e64fdc8b..03858cdb 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -8,7 +8,7 @@ from modules.shared import opts, state from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup from modules import extra_networks import k_diffusion.sampling - +from modules_forge import main_entry SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -161,45 +161,51 @@ replace_torchsde_browinan() def apply_refiner(cfg_denoiser, x): - # completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps - # refiner_switch_at = cfg_denoiser.p.refiner_switch_at - # refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info - # - # if refiner_switch_at is not None and completed_ratio < refiner_switch_at: - # return False - # - # if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info: - # return False - # - # if getattr(cfg_denoiser.p, "enable_hr", False): - # is_second_pass = cfg_denoiser.p.is_hr_pass - # - # if opts.hires_fix_refiner_pass == "first pass" and is_second_pass: - # return False - # - # if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass: - # return False - # - # if opts.hires_fix_refiner_pass != "second pass": - # cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass - # - # cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title - # cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at - # - # sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet) - # - # with sd_models.SkipWritingToConfig(): - # sd_models.reload_model_weights(info=refiner_checkpoint_info) - # - # if not cfg_denoiser.p.disable_extra_networks: - # extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data) - # - # cfg_denoiser.p.setup_conds() - # cfg_denoiser.update_inner_model() - # - # sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x) - # return True - pass + completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps + refiner_switch_at = cfg_denoiser.p.refiner_switch_at + refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info + + if refiner_switch_at is not None and completed_ratio < refiner_switch_at: + return False + + if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info: + return False + + if getattr(cfg_denoiser.p, "enable_hr", False): + is_second_pass = cfg_denoiser.p.is_hr_pass + + if opts.hires_fix_refiner_pass == "first pass" and is_second_pass: + return False + + if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass: + return False + + if opts.hires_fix_refiner_pass != "second pass": + cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass + + cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title + cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at + + sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet) + + with sd_models.SkipWritingToConfig(): + fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint') + checkpoint_changed = main_entry.checkpoint_change(refiner_checkpoint_info.short_title, save=False, refresh=False) + if checkpoint_changed: + try: + main_entry.refresh_model_loading_parameters() + sd_models.forge_model_reload() + finally: + main_entry.checkpoint_change(fp_checkpoint, save=False, refresh=True) + + if not cfg_denoiser.p.disable_extra_networks: + extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data) + + cfg_denoiser.p.setup_conds() + cfg_denoiser.update_inner_model() + + sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x) + return True class TorchHijack: