Update preprocessor_inpaint.py

This commit is contained in:
lllyasviel
2024-01-30 14:12:27 -08:00
parent 1b507433aa
commit b4dafc07b4

View File

@@ -20,8 +20,6 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
self.name = 'inpaint_only'
self.image = None
self.mask = None
self.latent_image = None
self.latent_mask = None
def process_before_every_sampling(self, process, cond, *args, **kwargs):
self.image = kwargs['cond_before_inpaint_fix'][:, 0:3]
@@ -30,14 +28,25 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
vae = process.sd_model.forge_objects.vae
# This is a powerful VAE with integrated memory management, bf16, and tiled fallback.
self.latent_image = vae.encode(self.image.movedim(1, -1))
B, C, H, W = self.latent_image.shape
latent_image = vae.encode(self.image.movedim(1, -1))
latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image)
B, C, H, W = latent_image.shape
latent_mask = self.mask
latent_mask = torch.nn.functional.interpolate(latent_mask, size=(H * 8, W * 8), mode="bilinear").round()
latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(self.latent_image)
self.latent_mask = latent_mask
latent_mask = torch.nn.functional.max_pool2d(latent_mask, (8, 8)).round().to(latent_image)
unet = process.sd_model.forge_objects.unet.clone()
def post_cfg(args):
denoised = args['denoised']
denoised = denoised * latent_mask.to(denoised) + latent_image.to(denoised) * (1.0 - latent_mask.to(denoised))
return denoised
unet.set_model_sampler_post_cfg_function(post_cfg)
process.sd_model.forge_objects.unet = unet
return
def process_after_every_sampling(self, process, params, *args, **kwargs):