From 88f6df4dcd4d12475ef375eed47160f80efadc1a Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Mon, 5 Feb 2024 14:48:43 -0800 Subject: [PATCH] add_sampler_pre_cfg_function --- .../scripts/preprocessor_inpaint.py | 6 ++++++ ldm_patched/modules/samplers.py | 3 +++ modules_forge/unet_patcher.py | 4 ++++ 3 files changed, 13 insertions(+) diff --git a/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py b/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py index 1ca70fcb..d91b06a9 100644 --- a/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py +++ b/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py @@ -55,11 +55,17 @@ class PreprocessorInpaintOnly(PreprocessorInpaint): unet = process.sd_model.forge_objects.unet.clone() + def pre_cfg(model, c, uc, x, timestep, model_options): + noisy_latent = latent_image.to(x) + timestep.to(x) * torch.randn_like(latent_image).to(x) + x = x * latent_mask.to(x) + noisy_latent.to(x) * (1.0 - latent_mask.to(x)) + return model, c, uc, x, timestep, model_options + def post_cfg(args): denoised = args['denoised'] denoised = denoised * latent_mask.to(denoised) + latent_image.to(denoised) * (1.0 - latent_mask.to(denoised)) return denoised + unet.add_sampler_pre_cfg_function(pre_cfg) unet.set_model_sampler_post_cfg_function(post_cfg) process.sd_model.forge_objects.unet = unet diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py index 4c12380a..866ec6cd 100644 --- a/ldm_patched/modules/samplers.py +++ b/ldm_patched/modules/samplers.py @@ -276,6 +276,9 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option else: uncond_ = uncond + for fn in model_options.get("sampler_pre_cfg_function", []): + model, cond, uncond_, x, timestep, model_options = fn(model, cond, uncond_, x, timestep, model_options) + cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) if "sampler_cfg_function" in model_options: args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index 4c00fa84..a46c03d2 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -102,6 +102,10 @@ class UnetPatcher(ModelPatcher): self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness) return + def add_sampler_pre_cfg_function(self, modifier, ensure_uniqueness=False): + self.append_model_option('sampler_pre_cfg_function', modifier, ensure_uniqueness) + return + def set_memory_peak_estimation_modifier(self, modifier): self.model_options['memory_peak_estimation_modifier'] = modifier return