diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 974c01ff..3fe61027 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -251,6 +251,7 @@ class AdapterConfig:
         # decimal for how often the control is dropped out and replaced with noise 1.0 is 100%
         self.control_image_dropout: float = kwargs.get('control_image_dropout', 0.0)
         self.has_inpainting_input: bool = kwargs.get('has_inpainting_input', False)
+        self.invert_inpaint_mask_chance: float = kwargs.get('invert_inpaint_mask_chance', 0.0)
 
 
 class EmbeddingConfig:
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 6dac3568..bc6453f4 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -579,23 +579,25 @@ class CustomAdapter(torch.nn.Module):
             # currently 0-1, we need rgb to be -1 to 1 before encoding with the vae
             inpainting_tensor_rgba = batch.inpaint_tensor.to(latents.device, dtype=latents.dtype)
             inpainting_tensor_mask = inpainting_tensor_rgba[:, 3:4, :, :]
-            inpainting_tensor_rgb = inpainting_tensor_rgba[:, :3, :, :]
-            # we need to make sure the inpaint area is black multiply the rgb channels by the mask
-            inpainting_tensor_rgb = inpainting_tensor_rgb * inpainting_tensor_mask
-            # if it is not the size of batch.tensor, (bs,ch,h,w) then we need to resize it
-            if inpainting_tensor_rgb.shape[2] != batch.tensor.shape[2] or inpainting_tensor_rgb.shape[3] != batch.tensor.shape[3]:
-                inpainting_tensor_rgb = F.interpolate(inpainting_tensor_rgb, size=(batch.tensor.shape[2], batch.tensor.shape[3]), mode='bilinear')
-
-            # scale to -1 to 1
-            inpainting_tensor_rgb = inpainting_tensor_rgb * 2 - 1
-
-            # encode it
-            inpainting_latent = sd.encode_images(inpainting_tensor_rgb).to(latents.device, latents.dtype)
+            # use our batch latents so we can avoid encoding again
+            inpainting_latent = batch.latents
             # resize the mask to match the new encoded size
             inpainting_tensor_mask = F.interpolate(inpainting_tensor_mask, size=(inpainting_latent.shape[2], inpainting_latent.shape[3]), mode='bilinear')
             inpainting_tensor_mask = inpainting_tensor_mask.to(latents.device, latents.dtype)
+
+            do_mask_invert = False
+            if self.config.invert_inpaint_mask_chance > 0.0:
+                do_mask_invert = random.random() < self.config.invert_inpaint_mask_chance
+            if do_mask_invert:
+                # invert the mask
+                inpainting_tensor_mask = 1 - inpainting_tensor_mask
+
+            # mask out the inpainting area; it is currently 0 for the inpaint area and 1 for the keep area
+            # we are zeroing out the latents in the inpaint area, not in pixel space
+            inpainting_latent = inpainting_latent * inpainting_tensor_mask
+            # the mask needs to be 1 for the inpaint area and 0 for the area to leave alone, so flip it
+            inpainting_tensor_mask = 1 - inpainting_tensor_mask
 
             # leave the mask as 0-1 and concat on channel of latents
diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py
index 62fd8710..3ccf6b3f 100644
--- a/toolkit/pipelines.py
+++ b/toolkit/pipelines.py
@@ -1609,8 +1609,8 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
                 mask = control_img_array[:, :, 3:4]
                 # scale it to 0 - 1
                 mask = mask / 255.0
-                # multiply rgb by mask
-                control_img_array = control_img_array[:, :, :3] * mask
+                # the control image ideally would be the full image here
+                control_img_array = control_img_array[:, :, :3]
 
                 control_image = Image.fromarray(control_img_array.astype(np.uint8))
 
         control_image = self.prepare_image(
@@ -1636,7 +1636,10 @@ class FluxAdvancedControlPipeline(FluxControlPipeline):
             # resize mask to match control image
             mask = F.interpolate(mask, size=(control_image.shape[2], control_image.shape[3]), mode="bilinear", align_corners=False)
             mask = mask.to(device)
-            # invert mask
+            # apply the mask to the control image so the inpaint latent area is 0
+            # mask is currently 0 for the inpaint area and 1 for the image area
+            control_image = control_image * mask
+            # invert the mask so it is 1 for the inpaint area and 0 for the image area
             mask = 1 - mask
             control_image = torch.cat([control_image, mask], dim=1)
             num_control_channels += 1
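
For reference, the conditioning the `custom_adapter.py` hunk builds can be summarized as a standalone function. This is a minimal sketch, not the repository's API: `build_inpaint_condition` and its parameter names are hypothetical, and it assumes the batch latents and the RGBA control tensor are already available.

```python
import random

import torch
import torch.nn.functional as F


def build_inpaint_condition(
    latents: torch.Tensor,       # (bs, ch, h, w) VAE latents for the batch
    inpaint_rgba: torch.Tensor,  # (bs, 4, H, W) control image; alpha is the keep mask in 0-1
    invert_mask_chance: float = 0.0,
) -> torch.Tensor:
    # take the alpha channel: 1 where the image is kept, 0 where it is inpainted
    mask = inpaint_rgba[:, 3:4, :, :]
    # downsample the pixel-space mask to the latent resolution
    mask = F.interpolate(mask, size=(latents.shape[2], latents.shape[3]), mode='bilinear')
    mask = mask.to(latents.device, latents.dtype)

    # training-time augmentation: occasionally invert the mask
    if invert_mask_chance > 0.0 and random.random() < invert_mask_chance:
        mask = 1 - mask

    # zero the latents in the inpaint region (mask is 0 there, 1 in the keep region)
    masked_latents = latents * mask
    # flip the mask so it is 1 in the inpaint region and 0 in the keep region,
    # then concatenate it as an extra conditioning channel
    return torch.cat([masked_latents, 1 - mask], dim=1)
```

Reusing `batch.latents` avoids a second VAE encode per step; zeroing the masked region directly in latent space stands in for the removed code path, which blacked out the pixels and re-encoded them.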