Change inpainting mask to zero out on latents instead of image for inpaint area.

This commit is contained in:
Jaret Burkett
2025-03-24 14:16:52 -06:00
parent 71d7a52146
commit 6021a3dbc0
3 changed files with 21 additions and 15 deletions

View File

@@ -579,23 +579,25 @@ class CustomAdapter(torch.nn.Module):
# currently 0-1, we need rgb to be -1 to 1 before encoding with the vae
inpainting_tensor_rgba = batch.inpaint_tensor.to(latents.device, dtype=latents.dtype)
inpainting_tensor_mask = inpainting_tensor_rgba[:, 3:4, :, :]
inpainting_tensor_rgb = inpainting_tensor_rgba[:, :3, :, :]
# we need to make sure the inpaint area is black multiply the rgb channels by the mask
inpainting_tensor_rgb = inpainting_tensor_rgb * inpainting_tensor_mask
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
if inpainting_tensor_rgb.shape[2] != batch.tensor.shape[2] or inpainting_tensor_rgb.shape[3] != batch.tensor.shape[3]:
inpainting_tensor_rgb = F.interpolate(inpainting_tensor_rgb, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear')
# scale to -1 to 1
inpainting_tensor_rgb = inpainting_tensor_rgb * 2 - 1
# encode it
inpainting_latent = sd.encode_images(inpainting_tensor_rgb).to(latents.device, latents.dtype)
# # use our batch latents so we cna avoid ancoding again
inpainting_latent = batch.latents
# resize the mask to match the new encoded size
inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear')
inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype)
do_mask_invert = False
if self.config.invert_inpaint_mask_chance > 0.0:
do_mask_invert = random.random() < self.config.invert_inpaint_mask_chance
if do_mask_invert:
# invert the mask
inpainting_tensor_mask = 1 - inpainting_tensor_mask
# mask out the inpainting area, it is currently 0 for inpaint area, and 1 for keep area
# we are zeroing our the latents in the inpaint area not on the pixel space.
inpainting_latent = inpainting_latent * inpainting_tensor_mask
# mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it.
inpainting_tensor_mask = 1 - inpainting_tensor_mask
# leave the mask as 0-1 and concat on channel of latents