diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 47d8f644..60aed8e0 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -95,7 +95,7 @@ class CFGDenoiser(torch.nn.Module):
         if state.interrupted or state.skipped:
             raise sd_samplers_common.InterruptedException
 
-        if sd_samplers_common.apply_refiner(self):
+        if sd_samplers_common.apply_refiner(self, x):
             cond = self.sampler.sampler_extra_args['cond']
             uncond = self.sampler.sampler_extra_args['uncond']
 
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 7d5b360b..2267db40 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -5,6 +5,7 @@ import torch
 from PIL import Image
 
 from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
 from modules.shared import opts, state
+from ldm_patched.modules import model_management
 
 import k_diffusion.sampling
@@ -151,7 +152,7 @@ def replace_torchsde_browinan():
 replace_torchsde_browinan()
 
 
-def apply_refiner(cfg_denoiser):
+def apply_refiner(cfg_denoiser, x):
     completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
     refiner_switch_at = cfg_denoiser.p.refiner_switch_at
     refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
@@ -180,6 +181,14 @@ def apply_refiner(cfg_denoiser):
     with sd_models.SkipWritingToConfig():
         sd_models.reload_model_weights(info=refiner_checkpoint_info)
 
+    refiner = sd_models.model_data.get_sd_model()
+
+    inference_memory = 0
+    unet_patcher = refiner.unet_patcher
+    model_management.load_models_gpu(
+        [unet_patcher],
+        unet_patcher.memory_required([x.shape[0]] + list(x.shape[1:])) + inference_memory)
+
     devices.torch_gc()
     cfg_denoiser.p.setup_conds()
     cfg_denoiser.update_inner_model()
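
For reviewers, a minimal standalone sketch of the idea behind the new load_models_gpu call: the latent batch x now passed into apply_refiner is used to estimate how much memory the refiner's UNet will need for this step before the patcher is moved to the GPU. estimate_required_memory and free_vram_bytes below are hypothetical stand-ins for unet_patcher.memory_required() and ldm_patched's model_management accounting; this is not the repository's API, only an illustration of the shape expression used in the diff.

import torch

def estimate_required_memory(shape, dtype=torch.float16, overhead_factor=1024):
    # Hypothetical stand-in for unet_patcher.memory_required(): scale the number
    # of latent elements by the element size and a rough activation-overhead factor.
    elements = 1
    for dim in shape:
        elements *= dim
    return elements * torch.tensor([], dtype=dtype).element_size() * overhead_factor

def plan_refiner_load(x, inference_memory=0, free_vram_bytes=8 * 1024**3):
    # Mirrors the expression used in the diff:
    # memory_required([x.shape[0]] + list(x.shape[1:])) + inference_memory
    required = estimate_required_memory([x.shape[0]] + list(x.shape[1:])) + inference_memory
    return required, required <= free_vram_bytes

if __name__ == "__main__":
    latent = torch.zeros(2, 4, 128, 128)  # a batch of 2 SDXL-sized latents
    required, fits = plan_refiner_load(latent)
    print(f"estimated bytes: {required}, fits in hypothetical budget: {fits}")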