mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-22 07:19:21 +00:00
Refiner (#2192)
restore refiner, now uses new mode loading functions. No significant changes otherwise.
This commit is contained in:
@@ -8,7 +8,7 @@ from modules.shared import opts, state
|
||||
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
|
||||
from modules import extra_networks
|
||||
import k_diffusion.sampling
|
||||
|
||||
from modules_forge import main_entry
|
||||
|
||||
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
@@ -161,45 +161,51 @@ replace_torchsde_browinan()
|
||||
|
||||
|
||||
def apply_refiner(cfg_denoiser, x):
|
||||
# completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
# refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||
# refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||
#
|
||||
# if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||
# return False
|
||||
#
|
||||
# if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||
# return False
|
||||
#
|
||||
# if getattr(cfg_denoiser.p, "enable_hr", False):
|
||||
# is_second_pass = cfg_denoiser.p.is_hr_pass
|
||||
#
|
||||
# if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
|
||||
# return False
|
||||
#
|
||||
# if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
|
||||
# return False
|
||||
#
|
||||
# if opts.hires_fix_refiner_pass != "second pass":
|
||||
# cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
|
||||
#
|
||||
# cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
# cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||
#
|
||||
# sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
|
||||
#
|
||||
# with sd_models.SkipWritingToConfig():
|
||||
# sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||
#
|
||||
# if not cfg_denoiser.p.disable_extra_networks:
|
||||
# extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
|
||||
#
|
||||
# cfg_denoiser.p.setup_conds()
|
||||
# cfg_denoiser.update_inner_model()
|
||||
#
|
||||
# sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
|
||||
# return True
|
||||
pass
|
||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||
|
||||
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||
return False
|
||||
|
||||
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||
return False
|
||||
|
||||
if getattr(cfg_denoiser.p, "enable_hr", False):
|
||||
is_second_pass = cfg_denoiser.p.is_hr_pass
|
||||
|
||||
if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
|
||||
return False
|
||||
|
||||
if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
|
||||
return False
|
||||
|
||||
if opts.hires_fix_refiner_pass != "second pass":
|
||||
cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
|
||||
|
||||
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||
|
||||
sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
|
||||
checkpoint_changed = main_entry.checkpoint_change(refiner_checkpoint_info.short_title, save=False, refresh=False)
|
||||
if checkpoint_changed:
|
||||
try:
|
||||
main_entry.refresh_model_loading_parameters()
|
||||
sd_models.forge_model_reload()
|
||||
finally:
|
||||
main_entry.checkpoint_change(fp_checkpoint, save=False, refresh=True)
|
||||
|
||||
if not cfg_denoiser.p.disable_extra_networks:
|
||||
extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
|
||||
|
||||
cfg_denoiser.p.setup_conds()
|
||||
cfg_denoiser.update_inner_model()
|
||||
|
||||
sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
|
||||
return True
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
|
||||
Reference in New Issue
Block a user