diff --git a/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py b/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py index eb8b06a7..2127da4b 100644 --- a/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py +++ b/extensions-builtin/forge_preprocessor_inpaint/scripts/preprocessor_inpaint.py @@ -20,8 +20,6 @@ class PreprocessorInpaintOnly(PreprocessorInpaint): self.name = 'inpaint_only' self.image = None self.mask = None - self.latent_image = None - self.latent_mask = None def process_before_every_sampling(self, process, cond, *args, **kwargs): self.image = kwargs['cond_before_inpaint_fix'][:, 0:3] @@ -30,14 +28,25 @@ class PreprocessorInpaintOnly(PreprocessorInpaint): vae = process.sd_model.forge_objects.vae # This is a powerful VAE with integrated memory management, bf16, and tiled fallback. - self.latent_image = vae.encode(self.image.movedim(1, -1)) - B, C, H, W = self.latent_image.shape + latent_image = vae.encode(self.image.movedim(1, -1)) + latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image) + + B, C, H, W = latent_image.shape latent_mask = self.mask latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round() - latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(self.latent_image) - self.latent_mask = latent_mask + latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(latent_image) + unet = process.sd_model.forge_objects.unet.clone() + + def post_cfg(args): + denoised = args['denoised'] + denoised = denoised * latent_mask.to(denoised) + latent_image.to(denoised) * (1.0 - latent_mask.to(denoised)) + return denoised + + unet.set_model_sampler_post_cfg_function(post_cfg) + + process.sd_model.forge_objects.unet = unet return def process_after_every_sampling(self, process, params, *args, **kwargs):