mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-26 17:29:09 +00:00
Refiner (#2192)
restore refiner, now uses new mode loading functions. No significant changes otherwise.
This commit is contained in:
@@ -21,34 +21,29 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
|
|||||||
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
|
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
|
||||||
gr.Markdown('Refiner is currently under maintenance and unavailable. Sorry for the inconvenience.')
|
with gr.Row():
|
||||||
|
refiner_checkpoint = gr.Dropdown(label='Checkpoint', info='(use model of same architecture)', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles(use_short=True)], value='', tooltip="switch to another model in the middle of generation")
|
||||||
|
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles(use_short=True)}, self.elem_id("checkpoint_refresh"))
|
||||||
|
|
||||||
|
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
|
||||||
|
|
||||||
|
def lookup_checkpoint(title):
|
||||||
|
info = sd_models.get_closet_checkpoint_match(title)
|
||||||
|
return None if info is None else info.short_title
|
||||||
|
|
||||||
|
self.infotext_fields = [
|
||||||
|
PasteField(enable_refiner, lambda d: 'Refiner' in d),
|
||||||
|
PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
|
||||||
|
PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
|
||||||
|
]
|
||||||
|
|
||||||
#
|
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||||
# with gr.Row():
|
|
||||||
# refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=["", *sd_models.checkpoint_tiles()], value='', tooltip="switch to another model in the middle of generation", interactive=False, visible=False)
|
|
||||||
# # create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
|
|
||||||
#
|
|
||||||
# refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation", interactive=False, visible=False)
|
|
||||||
#
|
|
||||||
# def lookup_checkpoint(title):
|
|
||||||
# info = sd_models.get_closet_checkpoint_match(title)
|
|
||||||
# return None if info is None else info.title
|
|
||||||
#
|
|
||||||
# self.infotext_fields = [
|
|
||||||
# PasteField(enable_refiner, lambda d: 'Refiner' in d),
|
|
||||||
# PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
|
|
||||||
# PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
|
|
||||||
# ]
|
|
||||||
|
|
||||||
return [enable_refiner] # , refiner_checkpoint, refiner_switch_at
|
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
||||||
|
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||||
def setup(self, p, enable_refiner):
|
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||||
pass
|
p.refiner_checkpoint = None
|
||||||
# # the actual implementation is in sd_samplers_common.py, apply_refiner
|
p.refiner_switch_at = None
|
||||||
#
|
else:
|
||||||
# if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
p.refiner_checkpoint = refiner_checkpoint
|
||||||
# p.refiner_checkpoint = None
|
p.refiner_switch_at = refiner_switch_at
|
||||||
# p.refiner_switch_at = None
|
|
||||||
# else:
|
|
||||||
# p.refiner_checkpoint = refiner_checkpoint
|
|
||||||
# p.refiner_switch_at = refiner_switch_at
|
|
||||||
|
|||||||
@@ -168,9 +168,9 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
|
x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
|
||||||
sigma = real_sigma
|
sigma = real_sigma
|
||||||
|
|
||||||
# if sd_samplers_common.apply_refiner(self, x):
|
if sd_samplers_common.apply_refiner(self, x):
|
||||||
# cond = self.sampler.sampler_extra_args['cond']
|
cond = self.sampler.sampler_extra_args['cond']
|
||||||
# uncond = self.sampler.sampler_extra_args['uncond']
|
uncond = self.sampler.sampler_extra_args['uncond']
|
||||||
|
|
||||||
cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) if uncond is not None else None
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) if uncond is not None else None
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from modules.shared import opts, state
|
|||||||
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
|
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
|
||||||
from modules import extra_networks
|
from modules import extra_networks
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
|
from modules_forge import main_entry
|
||||||
|
|
||||||
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||||
|
|
||||||
@@ -161,45 +161,51 @@ replace_torchsde_browinan()
|
|||||||
|
|
||||||
|
|
||||||
def apply_refiner(cfg_denoiser, x):
|
def apply_refiner(cfg_denoiser, x):
|
||||||
# completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||||
# refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||||
# refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||||
#
|
|
||||||
# if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||||
# return False
|
return False
|
||||||
#
|
|
||||||
# if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||||
# return False
|
return False
|
||||||
#
|
|
||||||
# if getattr(cfg_denoiser.p, "enable_hr", False):
|
if getattr(cfg_denoiser.p, "enable_hr", False):
|
||||||
# is_second_pass = cfg_denoiser.p.is_hr_pass
|
is_second_pass = cfg_denoiser.p.is_hr_pass
|
||||||
#
|
|
||||||
# if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
|
if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
|
||||||
# return False
|
return False
|
||||||
#
|
|
||||||
# if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
|
if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
|
||||||
# return False
|
return False
|
||||||
#
|
|
||||||
# if opts.hires_fix_refiner_pass != "second pass":
|
if opts.hires_fix_refiner_pass != "second pass":
|
||||||
# cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
|
cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
|
||||||
#
|
|
||||||
# cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||||
# cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||||
#
|
|
||||||
# sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
|
sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
|
||||||
#
|
|
||||||
# with sd_models.SkipWritingToConfig():
|
with sd_models.SkipWritingToConfig():
|
||||||
# sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
|
||||||
#
|
checkpoint_changed = main_entry.checkpoint_change(refiner_checkpoint_info.short_title, save=False, refresh=False)
|
||||||
# if not cfg_denoiser.p.disable_extra_networks:
|
if checkpoint_changed:
|
||||||
# extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
|
try:
|
||||||
#
|
main_entry.refresh_model_loading_parameters()
|
||||||
# cfg_denoiser.p.setup_conds()
|
sd_models.forge_model_reload()
|
||||||
# cfg_denoiser.update_inner_model()
|
finally:
|
||||||
#
|
main_entry.checkpoint_change(fp_checkpoint, save=False, refresh=True)
|
||||||
# sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
|
|
||||||
# return True
|
if not cfg_denoiser.p.disable_extra_networks:
|
||||||
pass
|
extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
|
||||||
|
|
||||||
|
cfg_denoiser.p.setup_conds()
|
||||||
|
cfg_denoiser.update_inner_model()
|
||||||
|
|
||||||
|
sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class TorchHijack:
|
class TorchHijack:
|
||||||
|
|||||||
Reference in New Issue
Block a user