add_sampler_pre_cfg_function

This commit is contained in:
lllyasviel
2024-02-05 14:48:43 -08:00
parent 8f86e66e5c
commit 88f6df4dcd
3 changed files with 13 additions and 0 deletions

View File

@@ -55,11 +55,17 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
unet = process.sd_model.forge_objects.unet.clone()
def pre_cfg(model, c, uc, x, timestep, model_options):
noisy_latent = latent_image.to(x) + timestep.to(x) * torch.randn_like(latent_image).to(x)
x = x * latent_mask.to(x) + noisy_latent.to(x) * (1.0 - latent_mask.to(x))
return model, c, uc, x, timestep, model_options
def post_cfg(args):
denoised = args['denoised']
denoised = denoised * latent_mask.to(denoised) + latent_image.to(denoised) * (1.0 - latent_mask.to(denoised))
return denoised
unet.add_sampler_pre_cfg_function(pre_cfg)
unet.set_model_sampler_post_cfg_function(post_cfg)
process.sd_model.forge_objects.unet = unet