# Mirror of https://github.com/ostris/ai-toolkit.git
# Synced 2026-02-03 12:24:59 +00:00
# 227 lines, 7.7 KiB, Python
import torch
|
|
import numpy as np
|
|
import os
|
|
import torch.nn.functional as F
|
|
from PIL import Image
|
|
import time
|
|
|
|
|
|
def generate_random_mask(
    batch_size,
    height=256,
    width=256,
    device='cuda',
    min_coverage=0.2,
    max_coverage=0.8,
    num_blobs_range=(1, 3)
):
    """
    Build a batch of random binary blob masks.

    Every mask is the union of a few rotated, stretched ellipses whose
    outline is warped by a sum of low-frequency sine products. Each blob is
    blurred and re-thresholded twice so its edge comes out smooth rather
    than circular or jagged.

    All random draws come from the global ``np.random`` state, so seeding
    NumPy makes the output reproducible.

    Args:
        batch_size (int): Number of masks to produce
        height (int): Mask height in pixels
        width (int): Mask width in pixels
        device (str): Torch device used for every tensor ('cuda' or 'cpu')
        min_coverage (float): Lower bound of the randomly drawn target coverage (0-1)
        max_coverage (float): Upper bound of the randomly drawn target coverage (0-1)
        num_blobs_range (tuple): Inclusive (min, max) number of blobs per mask

    Returns:
        torch.Tensor: Masks of 0.0/1.0 values with shape (batch_size, 1, height, width)

    Note:
        The target coverage is only approached, not enforced: a mask that is
        well below (or above) the target gets a single dilation (or erosion)
        step and nothing more.
    """
    def _blur(tensor, kernel, pad):
        # Reflect-pad by the kernel half-width so the un-padded convolution
        # returns a tensor of the original spatial size.
        padded = F.pad(tensor, (pad, pad, pad, pad), mode='reflect')
        return F.conv2d(padded, kernel, padding=0)

    batch = torch.zeros((batch_size, 1, height, width), device=device)

    # Pixel coordinate grids shared by every blob.
    rows = torch.arange(height, device=device).view(height, 1).expand(height, width)
    cols = torch.arange(width, device=device).view(1, width).expand(height, width)

    # Smoothing kernels in conv2d weight layout (out_ch, in_ch, kH, kW).
    kernel_7 = get_gaussian_kernel(7, 1.0).to(device).view(1, 1, 7, 7)
    kernel_15 = get_gaussian_kernel(15, 2.5).to(device).view(1, 1, 15, 15)

    # Blob radius bounds scale with the shorter image side.
    shortest_side = min(height, width)
    radius_hi = shortest_side // 3
    radius_lo = shortest_side // 8

    for index in range(batch_size):
        blob_count = np.random.randint(
            num_blobs_range[0], num_blobs_range[1] + 1)
        wanted_coverage = np.random.uniform(min_coverage, max_coverage)

        canvas = torch.zeros(1, 1, height, width, device=device)

        for _ in range(blob_count):
            # Low-frequency warp field: a sum of 2-4 sine products whose
            # amplitude shrinks with the wave index.
            warp = torch.zeros(height, width, device=device)
            for wave_idx in range(np.random.randint(2, 5)):
                freq_x = np.random.uniform(1.0, 3.0) * np.pi / width
                freq_y = np.random.uniform(1.0, 3.0) * np.pi / height
                phase_x = np.random.uniform(0, 2 * np.pi)
                phase_y = np.random.uniform(0, 2 * np.pi)
                amplitude = np.random.uniform(0.5, 1.0) * radius_hi / (wave_idx + 1.5)
                warp += torch.sin(cols * freq_x + phase_x) * \
                    torch.sin(rows * freq_y + phase_y) * amplitude

            # Ellipse centre (kept in the middle half of the image), size,
            # per-axis stretch and rotation angle.
            centre_y = np.random.randint(height // 4, 3 * height // 4)
            centre_x = np.random.randint(width // 4, 3 * width // 4)
            radius = np.random.randint(radius_lo, radius_hi)
            stretch_y = np.random.uniform(0.6, 1.4)
            stretch_x = np.random.uniform(0.6, 1.4)
            angle = np.random.uniform(0, 2 * np.pi)
            cos_a, sin_a = np.cos(angle), np.sin(angle)

            # Distance of every pixel from the centre in the stretched,
            # rotated frame, then pushed in/out by the warp field.
            offset_y = (rows - centre_y) * stretch_y
            offset_x = (cols - centre_x) * stretch_x
            turned_y = offset_y * cos_a - offset_x * sin_a
            turned_x = offset_y * sin_a + offset_x * cos_a
            warped_distance = torch.sqrt(turned_y**2 + turned_x**2) + warp

            blob = (warped_distance < radius).float()[None, None]

            # First pass: wide blur, then a random cut-off for shape variety.
            blob = _blur(blob, kernel_15, 7)
            cutoff = np.random.uniform(0.3, 0.6)
            blob = (blob > cutoff).float()

            # Second pass: narrow blur with a fixed 0.5 cut-off.
            blob = (_blur(blob, kernel_7, 3) > 0.5).float()

            # Union with the blobs drawn so far.
            canvas = torch.maximum(canvas, blob)

        # Nudge coverage toward the target; an empty mask is left untouched.
        coverage = canvas.mean().item()
        too_small = 0 < coverage < wanted_coverage * 0.7
        too_large = coverage > 0 and coverage > wanted_coverage * 1.3
        if too_small:
            # Dilate with a 5x5 max filter.
            canvas = F.pad(canvas, (2, 2, 2, 2), mode='reflect')
            canvas = F.max_pool2d(canvas, kernel_size=5, stride=1, padding=0)
        elif too_large:
            # Erode: keep pixels whose 3x3 neighbourhood is more than 70% set.
            canvas = F.pad(canvas, (1, 1, 1, 1), mode='reflect')
            canvas = F.avg_pool2d(canvas, kernel_size=3, stride=1, padding=0)
            canvas = (canvas > 0.7).float()

        # Final smooth and threshold, then store in the batch.
        canvas = (_blur(canvas, kernel_7, 3) > 0.5).float()
        batch[index] = canvas

    return batch
|
|
|
|
|
|
def get_gaussian_kernel(kernel_size=5, sigma=1.0):
    """
    Return a normalized 2D Gaussian kernel.

    The kernel is sampled on an integer pixel grid centred on the middle of
    the window, so ``sigma`` is a standard deviation measured in pixels.
    (Sampling on ``linspace(-2*sigma, 2*sigma, kernel_size)`` instead makes
    ``sigma`` cancel out of the exponent, yielding the same kernel for every
    ``sigma``.)

    Args:
        kernel_size (int): Side length of the square kernel, at least 1
        sigma (float): Standard deviation of the Gaussian in pixels, > 0

    Returns:
        torch.Tensor: float32 tensor of shape (kernel_size, kernel_size)
        whose entries sum to 1

    Raises:
        ValueError: If ``kernel_size`` is less than 1 or ``sigma`` is not
            positive.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")

    # Pixel offsets from the window centre, e.g. [-2, -1, 0, 1, 2] for size 5.
    offsets = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2

    # A 2D isotropic Gaussian is the outer product of two 1D profiles.
    profile = torch.exp(-offsets**2 / (2 * sigma**2))
    gaussian = profile.view(-1, 1) * profile.view(1, -1)

    # Normalize so that filtering preserves the mean intensity.
    gaussian /= gaussian.sum()

    return gaussian
|
|
|
|
|
|
def save_masks_as_images(masks, output_dir="output"):
    """
    Write each mask of a batch to disk as an RGB JPG image using PIL.

    Only channel 0 of every batch entry is used. Its values are multiplied
    by 255, cast to uint8 and replicated over three colour channels, so a
    0/1 mask becomes a white shape on a black background. Files are named
    ``mask_000.jpg``, ``mask_001.jpg``, ... inside ``output_dir``, which is
    created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    for index, sample in enumerate(masks):
        # Bring the single-channel plane to host memory as a NumPy array.
        plane = sample[0].cpu().numpy()

        # Scale to the 0-255 range of an 8-bit image.
        gray = (plane * 255).astype(np.uint8)

        # Same plane for R, G and B -> shape (height, width, 3).
        rgb = np.stack((gray, gray, gray), axis=2)

        target_path = os.path.join(output_dir, f"mask_{index:03d}.jpg")
        Image.fromarray(rgb).save(target_path, quality=95)
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo / micro-benchmark: generate the same batch five times, print the
    # wall-clock time of each run, then save the last batch as JPG files.

    # Parameters
    batch_size = 20
    height = 256
    width = 256
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"Generating {batch_size} random blob masks on {device}...")

    for _ in range(5):
        # CUDA kernels are launched asynchronously; synchronize before
        # reading the clock so each measurement covers exactly the GPU work
        # issued by this iteration.
        if device == 'cuda':
            torch.cuda.synchronize()
        # perf_counter is monotonic and has the best available resolution,
        # unlike time.time which can jump with system clock adjustments.
        start = time.perf_counter()
        masks = generate_random_mask(
            batch_size=batch_size,
            height=height,
            width=width,
            device=device,
            min_coverage=0.2,
            max_coverage=0.8,
            num_blobs_range=(1, 3)
        )
        if device == 'cuda':
            torch.cuda.synchronize()
        end = time.perf_counter()
        # print time in milliseconds
        print(f"Time taken: {(end - start)*1000:.2f} ms")

    print("Saving masks to 'output' directory...")
    save_masks_as_images(masks)

    print("Done!")
|