mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added an inpainting mask generator for training inpainting if inpaint mask is not provided
This commit is contained in:
@@ -24,6 +24,8 @@ from toolkit.train_tools import get_torch_dtype
|
||||
from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
|
||||
import random
|
||||
|
||||
from toolkit.util.mask import generate_random_mask
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||
from collections import OrderedDict
|
||||
@@ -575,10 +577,27 @@ class CustomAdapter(torch.nn.Module):
|
||||
inpainting_latent = None
|
||||
if self.config.has_inpainting_input:
|
||||
do_dropout = random.random() < self.config.control_image_dropout
|
||||
if batch.inpaint_tensor is not None and not do_dropout:
|
||||
# currently 0-1, we need rgb to be -1 to 1 before encoding with the vae
|
||||
inpainting_tensor_rgba = batch.inpaint_tensor.to(latents.device, dtype=latents.dtype)
|
||||
inpainting_tensor_mask = inpainting_tensor_rgba[:, 3:4, :, :]
|
||||
# do random mask if we dont have one
|
||||
inpaint_tensor = batch.inpaint_tensor
|
||||
if inpaint_tensor is None and not do_dropout:
|
||||
# generate a random one since we dont have one
|
||||
# this will make random blobs, invert the blobs for now as we normanlly inpaint the alpha
|
||||
inpaint_tensor = 1 - generate_random_mask(
|
||||
batch_size=latents.shape[0],
|
||||
height=latents.shape[2],
|
||||
width=latents.shape[3],
|
||||
device=latents.device,
|
||||
).to(latents.device, latents.dtype)
|
||||
if inpaint_tensor is not None and not do_dropout:
|
||||
|
||||
if inpaint_tensor.shape[1] == 4:
|
||||
# get just the mask
|
||||
inpainting_tensor_mask = inpaint_tensor[:, 3:4, :, :].to(latents.device, dtype=latents.dtype)
|
||||
elif inpaint_tensor.shape[1] == 3:
|
||||
# rgb mask. Just get one channel
|
||||
inpainting_tensor_mask = inpaint_tensor[:, 0:1, :, :].to(latents.device, dtype=latents.dtype)
|
||||
else:
|
||||
inpainting_tensor_mask = inpaint_tensor
|
||||
|
||||
# # use our batch latents so we cna avoid ancoding again
|
||||
inpainting_latent = batch.latents
|
||||
|
||||
226
toolkit/util/mask.py
Normal file
226
toolkit/util/mask.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
import time
|
||||
|
||||
|
||||
def generate_random_mask(
|
||||
batch_size,
|
||||
height=256,
|
||||
width=256,
|
||||
device='cuda',
|
||||
min_coverage=0.2,
|
||||
max_coverage=0.8,
|
||||
num_blobs_range=(1, 3)
|
||||
):
|
||||
"""
|
||||
Generate random blob masks for a batch of images.
|
||||
Fast GPU version with smooth, non-circular blob shapes.
|
||||
|
||||
Args:
|
||||
batch_size (int): Number of masks to generate
|
||||
height (int): Height of the mask
|
||||
width (int): Width of the mask
|
||||
device (str): Device to run the computation on ('cuda' or 'cpu')
|
||||
min_coverage (float): Minimum percentage of the image to be covered (0-1)
|
||||
max_coverage (float): Maximum percentage of the image to be covered (0-1)
|
||||
num_blobs_range (tuple): Range of number of blobs (min, max)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Binary masks with shape (batch_size, 1, height, width)
|
||||
"""
|
||||
# Initialize masks on GPU
|
||||
masks = torch.zeros((batch_size, 1, height, width), device=device)
|
||||
|
||||
# Pre-compute coordinate grid on GPU
|
||||
y_indices = torch.arange(height, device=device).view(
|
||||
height, 1).expand(height, width)
|
||||
x_indices = torch.arange(width, device=device).view(
|
||||
1, width).expand(height, width)
|
||||
|
||||
# Prepare gaussian kernels for smoothing
|
||||
small_kernel = get_gaussian_kernel(7, 1.0).to(device)
|
||||
small_kernel = small_kernel.view(1, 1, 7, 7)
|
||||
|
||||
large_kernel = get_gaussian_kernel(15, 2.5).to(device)
|
||||
large_kernel = large_kernel.view(1, 1, 15, 15)
|
||||
|
||||
# Constants
|
||||
max_radius = min(height, width) // 3
|
||||
min_radius = min(height, width) // 8
|
||||
|
||||
# For each mask in the batch
|
||||
for b in range(batch_size):
|
||||
# Determine number of blobs for this mask
|
||||
num_blobs = np.random.randint(
|
||||
num_blobs_range[0], num_blobs_range[1] + 1)
|
||||
|
||||
# Target coverage for this mask
|
||||
target_coverage = np.random.uniform(min_coverage, max_coverage)
|
||||
|
||||
# Initialize this mask
|
||||
mask = torch.zeros(1, 1, height, width, device=device)
|
||||
|
||||
# Generate blobs with smoother edges
|
||||
for _ in range(num_blobs):
|
||||
# Create a low-frequency noise field first (for smooth organic shapes)
|
||||
noise_field = torch.zeros(height, width, device=device)
|
||||
|
||||
# Use low-frequency sine waves to create base shape distortion
|
||||
# This creates smoother warping compared to pure random noise
|
||||
num_waves = np.random.randint(2, 5)
|
||||
for i in range(num_waves):
|
||||
freq_x = np.random.uniform(1.0, 3.0) * np.pi / width
|
||||
freq_y = np.random.uniform(1.0, 3.0) * np.pi / height
|
||||
phase_x = np.random.uniform(0, 2 * np.pi)
|
||||
phase_y = np.random.uniform(0, 2 * np.pi)
|
||||
amp = np.random.uniform(0.5, 1.0) * max_radius / (i+1.5)
|
||||
|
||||
# Generate smooth wave patterns
|
||||
wave = torch.sin(x_indices * freq_x + phase_x) * \
|
||||
torch.sin(y_indices * freq_y + phase_y) * amp
|
||||
noise_field += wave
|
||||
|
||||
# Basic ellipse parameters
|
||||
center_y = np.random.randint(height//4, 3*height//4)
|
||||
center_x = np.random.randint(width//4, 3*width//4)
|
||||
radius = np.random.randint(min_radius, max_radius)
|
||||
|
||||
# Squeeze and stretch the ellipse with random scaling
|
||||
scale_y = np.random.uniform(0.6, 1.4)
|
||||
scale_x = np.random.uniform(0.6, 1.4)
|
||||
|
||||
# Random rotation
|
||||
theta = np.random.uniform(0, 2 * np.pi)
|
||||
cos_theta, sin_theta = np.cos(theta), np.sin(theta)
|
||||
|
||||
# Calculate elliptical distance field
|
||||
y_scaled = (y_indices - center_y) * scale_y
|
||||
x_scaled = (x_indices - center_x) * scale_x
|
||||
|
||||
# Apply rotation
|
||||
rotated_y = y_scaled * cos_theta - x_scaled * sin_theta
|
||||
rotated_x = y_scaled * sin_theta + x_scaled * cos_theta
|
||||
|
||||
# Compute distances
|
||||
distances = torch.sqrt(rotated_y**2 + rotated_x**2)
|
||||
|
||||
# Apply the smooth noise field to the distance field
|
||||
perturbed_distances = distances + noise_field
|
||||
|
||||
# Create base blob
|
||||
blob = (perturbed_distances < radius).float(
|
||||
).unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# Apply strong smoothing for very smooth edges
|
||||
# Double smoothing to get really organic edges
|
||||
blob = F.pad(blob, (7, 7, 7, 7), mode='reflect')
|
||||
blob = F.conv2d(blob, large_kernel, padding=0)
|
||||
|
||||
# Apply threshold to get a nice shape
|
||||
rand_threshold = np.random.uniform(0.3, 0.6)
|
||||
blob = (blob > rand_threshold).float()
|
||||
|
||||
# Apply second smoothing pass
|
||||
blob = F.pad(blob, (3, 3, 3, 3), mode='reflect')
|
||||
blob = F.conv2d(blob, small_kernel, padding=0)
|
||||
blob = (blob > 0.5).float()
|
||||
|
||||
# Add to mask
|
||||
mask = torch.maximum(mask, blob)
|
||||
|
||||
# Ensure desired coverage
|
||||
current_coverage = mask.mean().item()
|
||||
|
||||
# Scale if needed to match target coverage
|
||||
if current_coverage > 0: # Avoid division by zero
|
||||
if current_coverage < target_coverage * 0.7: # Too small
|
||||
# Dilate mask to increase coverage
|
||||
mask = F.pad(mask, (2, 2, 2, 2), mode='reflect')
|
||||
mask = F.max_pool2d(mask, kernel_size=5, stride=1, padding=0)
|
||||
elif current_coverage > target_coverage * 1.3: # Too large
|
||||
# Erode mask to decrease coverage
|
||||
mask = F.pad(mask, (1, 1, 1, 1), mode='reflect')
|
||||
mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=0)
|
||||
mask = (mask > 0.7).float()
|
||||
|
||||
# Final smooth and threshold
|
||||
mask = F.pad(mask, (3, 3, 3, 3), mode='reflect')
|
||||
mask = F.conv2d(mask, small_kernel, padding=0)
|
||||
mask = (mask > 0.5).float()
|
||||
|
||||
# Add to batch
|
||||
masks[b] = mask
|
||||
|
||||
return masks
|
||||
|
||||
|
||||
def get_gaussian_kernel(kernel_size=5, sigma=1.0):
|
||||
"""
|
||||
Returns a 2D Gaussian kernel.
|
||||
"""
|
||||
# Create 1D kernels
|
||||
x = torch.linspace(-sigma * 2, sigma * 2, kernel_size)
|
||||
x = x.view(1, -1).repeat(kernel_size, 1)
|
||||
y = x.transpose(0, 1)
|
||||
|
||||
# 2D Gaussian
|
||||
gaussian = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
|
||||
gaussian /= gaussian.sum()
|
||||
|
||||
return gaussian
|
||||
|
||||
|
||||
def save_masks_as_images(masks, output_dir="output"):
|
||||
"""
|
||||
Save generated masks as RGB JPG images using PIL.
|
||||
"""
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
batch_size = masks.shape[0]
|
||||
for i in range(batch_size):
|
||||
# Convert mask to numpy array
|
||||
mask = masks[i, 0].cpu().numpy()
|
||||
|
||||
# Scale to 0-255 range and convert to uint8
|
||||
mask_255 = (mask * 255).astype(np.uint8)
|
||||
|
||||
# Create RGB image (white mask on black background)
|
||||
rgb_mask = np.stack([mask_255, mask_255, mask_255], axis=2)
|
||||
|
||||
# Convert to PIL Image and save
|
||||
img = Image.fromarray(rgb_mask)
|
||||
img.save(os.path.join(output_dir, f"mask_{i:03d}.jpg"), quality=95)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parameters
|
||||
batch_size = 20
|
||||
height = 256
|
||||
width = 256
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
print(f"Generating {batch_size} random blob masks on {device}...")
|
||||
|
||||
for i in range(5):
|
||||
# time it
|
||||
start = time.time()
|
||||
masks = generate_random_mask(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
device=device,
|
||||
min_coverage=0.2,
|
||||
max_coverage=0.8,
|
||||
num_blobs_range=(1, 3)
|
||||
)
|
||||
end = time.time()
|
||||
# print time in milliseconds
|
||||
print(f"Time taken: {(end - start)*1000:.2f} ms")
|
||||
|
||||
print(f"Saving masks to 'output' directory...")
|
||||
save_masks_as_images(masks)
|
||||
|
||||
print("Done!")
|
||||
Reference in New Issue
Block a user