Added mask diffirential mask dialation for flex2. Handle video for the i2v adapter

This commit is contained in:
Jaret Burkett
2025-04-10 11:50:01 -06:00
parent 9794416a5d
commit 059155174a
5 changed files with 118 additions and 3 deletions

View File

@@ -302,6 +302,16 @@ class CustomAdapter(torch.nn.Module):
# else:
raise NotImplementedError
def edit_batch_raw(self, batch: DataLoaderBatchDTO):
# happens on a raw batch before latents are created
return batch
def edit_batch_processed(self, batch: DataLoaderBatchDTO):
# happens after the latents are processed
if self.adapter_type == "i2v":
return self.i2v_adapter.edit_batch_processed(batch)
return batch
def setup_clip(self):
adapter_config = self.config
sd = self.sd_ref()

View File

@@ -592,6 +592,36 @@ class I2VAdapter(torch.nn.Module):
def condition_noisy_latents(self, latents: torch.Tensor, batch:DataLoaderBatchDTO):
# todo handle start frame
return latents
def edit_batch_processed(self, batch: DataLoaderBatchDTO):
with torch.no_grad():
# we will alway get a clip image frame, if one is not passed, use image
# or if video, pull from the first frame
# edit the batch to pull the first frame out of a video if we have it
# videos come in (bs, num_frames, channels, height, width)
tensor = batch.tensor
if batch.clip_image_tensor is None:
if len(tensor.shape) == 5:
# we have a video
first_frames = tensor[:, 0, :, :, :].clone()
else:
# we have a single image
first_frames = tensor.clone()
# it is -1 to 1, change it to 0 to 1
first_frames = (first_frames + 1) / 2
# clip image tensors are preprocessed.
tensors_0_1 = first_frames.to(dtype=torch.float16)
clip_out = self.adapter_ref().clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
batch.clip_image_tensor = clip_out.to(self.device_torch)
return batch
@property
def is_active(self):

View File

@@ -4,6 +4,7 @@ import os
import torch.nn.functional as F
from PIL import Image
import time
import random
def generate_random_mask(
@@ -173,7 +174,7 @@ def get_gaussian_kernel(kernel_size=5, sigma=1.0):
return gaussian
def save_masks_as_images(masks, output_dir="output"):
def save_masks_as_images(masks, suffix="", output_dir="output"):
"""
Save generated masks as RGB JPG images using PIL.
"""
@@ -192,7 +193,65 @@ def save_masks_as_images(masks, output_dir="output"):
# Convert to PIL Image and save
img = Image.fromarray(rgb_mask)
img.save(os.path.join(output_dir, f"mask_{i:03d}.jpg"), quality=95)
img.save(os.path.join(output_dir, f"mask_{i:03d}{suffix}.jpg"), quality=95)
def random_dialate_mask(mask, max_percent=0.05):
"""
Randomly dialates a binary mask with a kernel of random size.
Args:
mask (torch.Tensor): Input mask of shape [batch_size, channels, height, width]
max_percent (float): Maximum kernel size as a percentage of the mask size
Returns:
torch.Tensor: Dialated mask with the same shape as input
"""
size = mask.shape[-1]
max_size = int(size * max_percent)
# Handle case where max_size is too small
if max_size < 3:
max_size = 3
batch_chunks = torch.chunk(mask, mask.shape[0], dim=0)
out_chunks = []
for i in range(len(batch_chunks)):
chunk = batch_chunks[i]
# Ensure kernel size is odd for proper padding
kernel_size = np.random.randint(1, max_size)
# If kernel_size is less than 2, keep the original mask
if kernel_size < 2:
out_chunks.append(chunk)
continue
# Make sure kernel size is odd
if kernel_size % 2 == 0:
kernel_size += 1
# Create normalized dilation kernel
kernel = torch.ones((1, 1, kernel_size, kernel_size), device=mask.device) / (kernel_size * kernel_size)
# Pad the mask for convolution
padding = kernel_size // 2
padded_mask = F.pad(chunk, (padding, padding, padding, padding), mode='constant', value=0)
# Apply convolution
dilated = F.conv2d(padded_mask, kernel)
# Random threshold for varied dilation effect
threshold = np.random.uniform(0.2, 0.8)
# Apply threshold
dilated = (dilated > threshold).float()
out_chunks.append(dilated)
return torch.cat(out_chunks, dim=0)
if __name__ == "__main__":
@@ -216,11 +275,14 @@ if __name__ == "__main__":
max_coverage=0.8,
num_blobs_range=(1, 3)
)
dialation = random_dialate_mask(masks)
print(f"Generated {batch_size} masks with shape: {masks.shape}")
end = time.time()
# print time in milliseconds
print(f"Time taken: {(end - start)*1000:.2f} ms")
print(f"Saving masks to 'output' directory...")
save_masks_as_images(masks)
save_masks_as_images(dialation, suffix="_dilated" )
print("Done!")