mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
Change inpainting mask to zero out on latents instead of image for inpaint area.
This commit is contained in:
@@ -251,6 +251,7 @@ class AdapterConfig:
|
||||
# decimal for how often the control is dropped out and replaced with noise; 1.0 is 100%
|
||||
self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
|
||||
self.has_inpainting_input: bool = kwargs.get('has_inpainting_input', False)
|
||||
self.invert_inpaint_mask_chance: float = kwargs.get('invert_inpaint_mask_chance', 0.0)
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
|
||||
@@ -579,23 +579,25 @@ class CustomAdapter(torch.nn.Module):
|
||||
# currently 0-1, we need rgb to be -1 to 1 before encoding with the vae
|
||||
inpainting_tensor_rgba = batch.inpaint_tensor.to(latents.device, dtype=latents.dtype)
|
||||
inpainting_tensor_mask = inpainting_tensor_rgba[:, 3:4, :, :]
|
||||
inpainting_tensor_rgb = inpainting_tensor_rgba[:, :3, :, :]
|
||||
# we need to make sure the inpaint area is black, so multiply the rgb channels by the mask
|
||||
inpainting_tensor_rgb = inpainting_tensor_rgb * inpainting_tensor_mask
|
||||
|
||||
# if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
|
||||
if inpainting_tensor_rgb.shape[2] != batch.tensor.shape[2] or inpainting_tensor_rgb.shape[3] != batch.tensor.shape[3]:
|
||||
inpainting_tensor_rgb = F.interpolate(inpainting_tensor_rgb, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear')
|
||||
|
||||
# scale to -1 to 1
|
||||
inpainting_tensor_rgb = inpainting_tensor_rgb * 2 - 1
|
||||
|
||||
# encode it
|
||||
inpainting_latent = sd.encode_images(inpainting_tensor_rgb).to(latents.device, latents.dtype)
|
||||
# use our batch latents so we can avoid encoding again
|
||||
inpainting_latent = batch.latents
|
||||
|
||||
# resize the mask to match the new encoded size
|
||||
inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear')
|
||||
inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype)
|
||||
|
||||
do_mask_invert = False
|
||||
if self.config.invert_inpaint_mask_chance > 0.0:
|
||||
do_mask_invert = random.random() < self.config.invert_inpaint_mask_chance
|
||||
if do_mask_invert:
|
||||
# invert the mask
|
||||
inpainting_tensor_mask = 1 - inpainting_tensor_mask
|
||||
|
||||
# mask out the inpainting area, it is currently 0 for inpaint area, and 1 for keep area
|
||||
# we are zeroing out the latents in the inpaint area, not in pixel space.
|
||||
inpainting_latent = inpainting_latent * inpainting_tensor_mask
|
||||
|
||||
# mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it.
|
||||
inpainting_tensor_mask = 1 - inpainting_tensor_mask
|
||||
# leave the mask as 0-1 and concat on channel of latents
|
||||
|
||||
@@ -1609,8 +1609,8 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
mask = control_img_array[:, :, 3:4]
|
||||
# scale it to 0 - 1
|
||||
mask = mask / 255.0
|
||||
# multiply rgb by mask
|
||||
control_img_array = control_img_array[:, :, :3] * mask
|
||||
# control image ideally would be a full image here
|
||||
control_img_array = control_img_array[:, :, :3]
|
||||
control_image = Image.fromarray(control_img_array.astype(np.uint8))
|
||||
|
||||
control_image = self.prepare_image(
|
||||
@@ -1636,7 +1636,10 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
|
||||
# resize mask to match control image
|
||||
mask = F.interpolate(mask, size=(control_image.shape[2], control_image.shape[3]), mode="bilinear", align_corners=False)
|
||||
mask = mask.to(device)
|
||||
# invert mask
|
||||
# apply the mask to the control image so the inpaint latent area is 0
|
||||
# mask is currently 0 for inpaint area and 1 for image area
|
||||
control_image = control_image * mask
|
||||
# invert mask so it is 1 for inpaint area and 0 for image area
|
||||
mask = 1 - mask
|
||||
control_image = torch.cat([control_image, mask], dim=1)
|
||||
num_control_channels += 1
|
||||
|
||||
Reference in New Issue
Block a user