mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added mask diffirential mask dialation for flex2. Handle video for the i2v adapter
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
import time
|
||||
import random
|
||||
|
||||
|
||||
def generate_random_mask(
|
||||
@@ -173,7 +174,7 @@ def get_gaussian_kernel(kernel_size=5, sigma=1.0):
|
||||
return gaussian
|
||||
|
||||
|
||||
def save_masks_as_images(masks, output_dir="output"):
|
||||
def save_masks_as_images(masks, suffix="", output_dir="output"):
|
||||
"""
|
||||
Save generated masks as RGB JPG images using PIL.
|
||||
"""
|
||||
@@ -192,7 +193,65 @@ def save_masks_as_images(masks, output_dir="output"):
|
||||
|
||||
# Convert to PIL Image and save
|
||||
img = Image.fromarray(rgb_mask)
|
||||
img.save(os.path.join(output_dir, f"mask_{i:03d}.jpg"), quality=95)
|
||||
img.save(os.path.join(output_dir, f"mask_{i:03d}{suffix}.jpg"), quality=95)
|
||||
|
||||
|
||||
def random_dialate_mask(mask, max_percent=0.05):
|
||||
"""
|
||||
Randomly dialates a binary mask with a kernel of random size.
|
||||
|
||||
Args:
|
||||
mask (torch.Tensor): Input mask of shape [batch_size, channels, height, width]
|
||||
max_percent (float): Maximum kernel size as a percentage of the mask size
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Dialated mask with the same shape as input
|
||||
"""
|
||||
|
||||
size = mask.shape[-1]
|
||||
max_size = int(size * max_percent)
|
||||
|
||||
# Handle case where max_size is too small
|
||||
if max_size < 3:
|
||||
max_size = 3
|
||||
|
||||
batch_chunks = torch.chunk(mask, mask.shape[0], dim=0)
|
||||
out_chunks = []
|
||||
|
||||
for i in range(len(batch_chunks)):
|
||||
chunk = batch_chunks[i]
|
||||
|
||||
# Ensure kernel size is odd for proper padding
|
||||
kernel_size = np.random.randint(1, max_size)
|
||||
|
||||
# If kernel_size is less than 2, keep the original mask
|
||||
if kernel_size < 2:
|
||||
out_chunks.append(chunk)
|
||||
continue
|
||||
|
||||
# Make sure kernel size is odd
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
|
||||
# Create normalized dilation kernel
|
||||
kernel = torch.ones((1, 1, kernel_size, kernel_size), device=mask.device) / (kernel_size * kernel_size)
|
||||
|
||||
# Pad the mask for convolution
|
||||
padding = kernel_size // 2
|
||||
padded_mask = F.pad(chunk, (padding, padding, padding, padding), mode='constant', value=0)
|
||||
|
||||
# Apply convolution
|
||||
dilated = F.conv2d(padded_mask, kernel)
|
||||
|
||||
# Random threshold for varied dilation effect
|
||||
threshold = np.random.uniform(0.2, 0.8)
|
||||
|
||||
# Apply threshold
|
||||
dilated = (dilated > threshold).float()
|
||||
|
||||
out_chunks.append(dilated)
|
||||
|
||||
return torch.cat(out_chunks, dim=0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -216,11 +275,14 @@ if __name__ == "__main__":
|
||||
max_coverage=0.8,
|
||||
num_blobs_range=(1, 3)
|
||||
)
|
||||
dialation = random_dialate_mask(masks)
|
||||
print(f"Generated {batch_size} masks with shape: {masks.shape}")
|
||||
end = time.time()
|
||||
# print time in milliseconds
|
||||
print(f"Time taken: {(end - start)*1000:.2f} ms")
|
||||
|
||||
print(f"Saving masks to 'output' directory...")
|
||||
save_masks_as_images(masks)
|
||||
save_masks_as_images(dialation, suffix="_dilated" )
|
||||
|
||||
print("Done!")
|
||||
|
||||
Reference in New Issue
Block a user