restore refiner, now uses new mode loading functions. No significant changes otherwise.
This commit is contained in:
DenOfEquity
2024-11-11 11:20:48 +00:00
committed by GitHub
parent b17d43f706
commit 98e0adcc78
3 changed files with 73 additions and 72 deletions

View File

@@ -8,7 +8,7 @@ from modules.shared import opts, state
from backend.sampling.sampling_function import sampling_prepare, sampling_cleanup
from modules import extra_networks
import k_diffusion.sampling
from modules_forge import main_entry
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -161,45 +161,51 @@ replace_torchsde_browinan()
def apply_refiner(cfg_denoiser, x):
# completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
# refiner_switch_at = cfg_denoiser.p.refiner_switch_at
# refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
#
# if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
# return False
#
# if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
# return False
#
# if getattr(cfg_denoiser.p, "enable_hr", False):
# is_second_pass = cfg_denoiser.p.is_hr_pass
#
# if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
# return False
#
# if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
# return False
#
# if opts.hires_fix_refiner_pass != "second pass":
# cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
#
# cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
# cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
#
# sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
#
# with sd_models.SkipWritingToConfig():
# sd_models.reload_model_weights(info=refiner_checkpoint_info)
#
# if not cfg_denoiser.p.disable_extra_networks:
# extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
#
# cfg_denoiser.p.setup_conds()
# cfg_denoiser.update_inner_model()
#
# sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
# return True
pass
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
return False
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
return False
if getattr(cfg_denoiser.p, "enable_hr", False):
is_second_pass = cfg_denoiser.p.is_hr_pass
if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
return False
if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
return False
if opts.hires_fix_refiner_pass != "second pass":
cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
sampling_cleanup(sd_models.model_data.get_sd_model().forge_objects.unet)
with sd_models.SkipWritingToConfig():
fp_checkpoint = getattr(shared.opts, 'sd_model_checkpoint')
checkpoint_changed = main_entry.checkpoint_change(refiner_checkpoint_info.short_title, save=False, refresh=False)
if checkpoint_changed:
try:
main_entry.refresh_model_loading_parameters()
sd_models.forge_model_reload()
finally:
main_entry.checkpoint_change(fp_checkpoint, save=False, refresh=True)
if not cfg_denoiser.p.disable_extra_networks:
extra_networks.activate(cfg_denoiser.p, cfg_denoiser.p.extra_network_data)
cfg_denoiser.p.setup_conds()
cfg_denoiser.update_inner_model()
sampling_prepare(sd_models.model_data.get_sd_model().forge_objects.unet, x=x)
return True
class TorchHijack: