From 059155174ac065de33ab7496a1b682245eeb234e Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 10 Apr 2025 11:50:01 -0600
Subject: [PATCH] Added differential mask dilation for flex2. Handle video for the i2v adapter

---
 extensions_built_in/flex2/flex2.py          | 11 +++-
 extensions_built_in/sd_trainer/SDTrainer.py |  4 ++
 toolkit/custom_adapter.py                   | 10 ++++
 toolkit/models/i2v_adapter.py               | 30 ++++++++++
 toolkit/util/mask.py                        | 66 ++++++++++++++++++++-
 5 files changed, 118 insertions(+), 3 deletions(-)

diff --git a/extensions_built_in/flex2/flex2.py b/extensions_built_in/flex2/flex2.py
index 7bb89506..d0d73be6 100644
--- a/extensions_built_in/flex2/flex2.py
+++ b/extensions_built_in/flex2/flex2.py
@@ -16,7 +16,7 @@ from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guid
 from toolkit.dequantize import patch_dequantization_on_save
 from toolkit.accelerator import get_accelerator, unwrap_model
 from optimum.quanto import freeze, QTensor
-from toolkit.util.mask import generate_random_mask
+from toolkit.util.mask import generate_random_mask, random_dialate_mask
 from toolkit.util.quantize import quantize, get_qtype
 from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
 from .pipeline import Flex2Pipeline
@@ -77,6 +77,7 @@ class Flex2(BaseModel):
         self.control_dropout = model_config.model_kwargs.get('control_dropout', 0.0)
         self.inpaint_random_chance = model_config.model_kwargs.get('inpaint_random_chance', 0.0)
         self.random_blur_mask = model_config.model_kwargs.get('random_blur_mask', False)
+        self.random_dialate_mask = model_config.model_kwargs.get('random_dialate_mask', False)
         self.do_random_inpainting = model_config.model_kwargs.get('do_random_inpainting', False)
 
     # static method to get the noise scheduler
@@ -446,6 +447,14 @@
             # we are zeroing out the latents in the inpaint area, not in pixel space.
             inpainting_latent = inpainting_latent * inpainting_tensor_mask
+            # do the random dilation after the mask is applied so it does not match perfectly.
+            # this will make the model learn to prevent weird edges
+            if self.random_dialate_mask:
+                inpainting_tensor_mask = random_dialate_mask(
+                    inpainting_tensor_mask,
+                    max_percent=0.05
+                )
+
             # mask needs to be 1 for inpaint area and 0 for area to leave alone. So flip it.
             inpainting_tensor_mask = 1 - inpainting_tensor_mask
 
             # leave the mask as 0-1 and concat on channel of latents
diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 682c87f7..ecdbfbf5 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -771,7 +771,11 @@ class SDTrainer(BaseSDTrainProcess):
 
     def train_single_accumulation(self, batch: DataLoaderBatchDTO):
         self.timer.start('preprocess_batch')
+        if isinstance(self.adapter, CustomAdapter):
+            batch = self.adapter.edit_batch_raw(batch)
         batch = self.preprocess_batch(batch)
+        if isinstance(self.adapter, CustomAdapter):
+            batch = self.adapter.edit_batch_processed(batch)
         dtype = get_torch_dtype(self.train_config.dtype)
         # sanity check
         if self.sd.vae.dtype != self.sd.vae_torch_dtype:
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 1ac999e5..5761bf03 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -302,6 +302,16 @@ class CustomAdapter(torch.nn.Module):
             # else:
             raise NotImplementedError
 
+    def edit_batch_raw(self, batch: DataLoaderBatchDTO):
+        # happens on a raw batch before latents are created
+        return batch
+
+    def edit_batch_processed(self, batch: DataLoaderBatchDTO):
+        # happens after the latents are processed
+        if self.adapter_type == "i2v":
+            return self.i2v_adapter.edit_batch_processed(batch)
+        return batch
+
     def setup_clip(self):
         adapter_config = self.config
         sd = self.sd_ref()
diff --git a/toolkit/models/i2v_adapter.py b/toolkit/models/i2v_adapter.py
index 09f45995..dcf4c2f2 100644
--- a/toolkit/models/i2v_adapter.py
+++ b/toolkit/models/i2v_adapter.py
@@ -592,6 +592,36 @@ class I2VAdapter(torch.nn.Module):
     def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO):
         # todo handle start frame
         return latents
+
+    def edit_batch_processed(self, batch: DataLoaderBatchDTO):
+        with torch.no_grad():
+            # we will always get a clip image frame; if one is not passed, use the image,
+            # or, if it is a video, pull it from the first frame.
+            # edit the batch to pull the first frame out of a video if we have one
+            # videos come in (bs, num_frames, channels, height, width)
+            tensor = batch.tensor
+            if batch.clip_image_tensor is None:
+                if len(tensor.shape) == 5:
+                    # we have a video
+                    first_frames = tensor[:, 0, :, :, :].clone()
+                else:
+                    # we have a single image
+                    first_frames = tensor.clone()
+
+                # it is -1 to 1, change it to 0 to 1
+                first_frames = (first_frames + 1) / 2
+
+                # clip image tensors are preprocessed.
+                tensors_0_1 = first_frames.to(dtype=torch.float16)
+                clip_out = self.adapter_ref().clip_image_processor(
+                    images=tensors_0_1,
+                    return_tensors="pt",
+                    do_resize=True,
+                    do_rescale=False,
+                ).pixel_values
+
+                batch.clip_image_tensor = clip_out.to(self.device_torch)
+        return batch
 
     @property
     def is_active(self):
diff --git a/toolkit/util/mask.py b/toolkit/util/mask.py
index 65334281..1b80f1ca 100644
--- a/toolkit/util/mask.py
+++ b/toolkit/util/mask.py
@@ -4,6 +4,7 @@ import os
 import torch.nn.functional as F
 from PIL import Image
 import time
+import random
 
 
 def generate_random_mask(
@@ -173,7 +174,7 @@ def get_gaussian_kernel(kernel_size=5, sigma=1.0):
     return gaussian
 
 
-def save_masks_as_images(masks, output_dir="output"):
+def save_masks_as_images(masks, suffix="", output_dir="output"):
     """
     Save generated masks as RGB JPG images using PIL.
""" @@ -192,7 +193,65 @@ def save_masks_as_images(masks, output_dir="output"): # Convert to PIL Image and save img = Image.fromarray(rgb_mask) - img.save(os.path.join(output_dir, f"mask_{i:03d}.jpg"), quality=95) + img.save(os.path.join(output_dir, f"mask_{i:03d}{suffix}.jpg"), quality=95) + + +def random_dialate_mask(mask, max_percent=0.05): + """ + Randomly dialates a binary mask with a kernel of random size. + + Args: + mask (torch.Tensor): Input mask of shape [batch_size, channels, height, width] + max_percent (float): Maximum kernel size as a percentage of the mask size + + Returns: + torch.Tensor: Dialated mask with the same shape as input + """ + + size = mask.shape[-1] + max_size = int(size * max_percent) + + # Handle case where max_size is too small + if max_size < 3: + max_size = 3 + + batch_chunks = torch.chunk(mask, mask.shape[0], dim=0) + out_chunks = [] + + for i in range(len(batch_chunks)): + chunk = batch_chunks[i] + + # Ensure kernel size is odd for proper padding + kernel_size = np.random.randint(1, max_size) + + # If kernel_size is less than 2, keep the original mask + if kernel_size < 2: + out_chunks.append(chunk) + continue + + # Make sure kernel size is odd + if kernel_size % 2 == 0: + kernel_size += 1 + + # Create normalized dilation kernel + kernel = torch.ones((1, 1, kernel_size, kernel_size), device=mask.device) / (kernel_size * kernel_size) + + # Pad the mask for convolution + padding = kernel_size // 2 + padded_mask = F.pad(chunk, (padding, padding, padding, padding), mode='constant', value=0) + + # Apply convolution + dilated = F.conv2d(padded_mask, kernel) + + # Random threshold for varied dilation effect + threshold = np.random.uniform(0.2, 0.8) + + # Apply threshold + dilated = (dilated > threshold).float() + + out_chunks.append(dilated) + + return torch.cat(out_chunks, dim=0) if __name__ == "__main__": @@ -216,11 +275,14 @@ if __name__ == "__main__": max_coverage=0.8, num_blobs_range=(1, 3) ) + dialation = random_dialate_mask(masks) + print(f"Generated {batch_size} masks with shape: {masks.shape}") end = time.time() # print time in milliseconds print(f"Time taken: {(end - start)*1000:.2f} ms") print(f"Saving masks to 'output' directory...") save_masks_as_images(masks) + save_masks_as_images(dialation, suffix="_dilated" ) print("Done!")