mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
i
This commit is contained in:
@@ -95,7 +95,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
raise sd_samplers_common.InterruptedException
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
if sd_samplers_common.apply_refiner(self):
|
if sd_samplers_common.apply_refiner(self, x):
|
||||||
cond = self.sampler.sampler_extra_args['cond']
|
cond = self.sampler.sampler_extra_args['cond']
|
||||||
uncond = self.sampler.sampler_extra_args['uncond']
|
uncond = self.sampler.sampler_extra_args['uncond']
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
|
from ldm_patched.modules import model_management
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
|
|
||||||
|
|
||||||
@@ -151,7 +152,7 @@ def replace_torchsde_browinan():
|
|||||||
replace_torchsde_browinan()
|
replace_torchsde_browinan()
|
||||||
|
|
||||||
|
|
||||||
def apply_refiner(cfg_denoiser):
|
def apply_refiner(cfg_denoiser, x):
|
||||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||||
@@ -180,6 +181,14 @@ def apply_refiner(cfg_denoiser):
|
|||||||
with sd_models.SkipWritingToConfig():
|
with sd_models.SkipWritingToConfig():
|
||||||
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||||
|
|
||||||
|
refiner = sd_models.model_data.get_sd_model()
|
||||||
|
|
||||||
|
inference_memory = 0
|
||||||
|
unet_patcher = refiner.unet_patcher
|
||||||
|
model_management.load_models_gpu(
|
||||||
|
[unet_patcher],
|
||||||
|
unet_patcher.memory_required([x.shape[0]] + list(x.shape[1:])) + inference_memory)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
cfg_denoiser.p.setup_conds()
|
cfg_denoiser.p.setup_conds()
|
||||||
cfg_denoiser.update_inner_model()
|
cfg_denoiser.update_inner_model()
|
||||||
|
|||||||
Reference in New Issue
Block a user