diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index bc6453f4..fcc6cda9 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -24,6 +24,8 @@
 from toolkit.train_tools import get_torch_dtype
 from toolkit.models.pixtral_vision import PixtralVisionEncoderCompatible, PixtralVisionImagePreprocessorCompatible
 import random
+from toolkit.util.mask import generate_random_mask
+
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
 from collections import OrderedDict
@@ -575,10 +577,27 @@ class CustomAdapter(torch.nn.Module):
         inpainting_latent = None
         if self.config.has_inpainting_input:
             do_dropout = random.random() < self.config.control_image_dropout
-            if batch.inpaint_tensor is not None and not do_dropout:
-                # currently 0-1, we need rgb to be -1 to 1 before encoding with the vae
-                inpainting_tensor_rgba = batch.inpaint_tensor.to(latents.device, dtype=latents.dtype)
-                inpainting_tensor_mask = inpainting_tensor_rgba[:, 3:4, :, :]
+            # use a random mask if we don't have one
+            inpaint_tensor = batch.inpaint_tensor
+            if inpaint_tensor is None and not do_dropout:
+                # generate a random one since we don't have one
+                # this makes random blobs; invert them for now since we normally inpaint the alpha
+                inpaint_tensor = 1 - generate_random_mask(
+                    batch_size=latents.shape[0],
+                    height=latents.shape[2],
+                    width=latents.shape[3],
+                    device=latents.device,
+                ).to(latents.device, latents.dtype)
+            if inpaint_tensor is not None and not do_dropout:
+                if inpaint_tensor.shape[1] == 4:
+                    # RGBA: take just the alpha channel as the mask
+                    inpainting_tensor_mask = inpaint_tensor[:, 3:4, :, :].to(latents.device, dtype=latents.dtype)
+                elif inpaint_tensor.shape[1] == 3:
+                    # RGB mask: just take one channel
+                    inpainting_tensor_mask = inpaint_tensor[:, 0:1, :, :].to(latents.device, dtype=latents.dtype)
+                else:
+                    # already a single-channel mask
+                    inpainting_tensor_mask = inpaint_tensor.to(latents.device, dtype=latents.dtype)
 
                 # use our batch latents so we can avoid encoding again
                 inpainting_latent = batch.latents
diff --git a/toolkit/util/mask.py b/toolkit/util/mask.py
new file mode 100644
index 00000000..65334281
--- /dev/null
+++ b/toolkit/util/mask.py
@@ -0,0 +1,226 @@
+import torch
+import numpy as np
+import os
+import torch.nn.functional as F
+from PIL import Image
+import time
+
+
+def generate_random_mask(
+    batch_size,
+    height=256,
+    width=256,
+    device='cuda',
+    min_coverage=0.2,
+    max_coverage=0.8,
+    num_blobs_range=(1, 3)
+):
+    """
+    Generate random blob masks for a batch of images.
+    Fast GPU version with smooth, non-circular blob shapes.
+
+    Args:
+        batch_size (int): Number of masks to generate
+        height (int): Height of the mask
+        width (int): Width of the mask
+        device (str): Device to run the computation on ('cuda' or 'cpu')
+        min_coverage (float): Minimum percentage of the image to be covered (0-1)
+        max_coverage (float): Maximum percentage of the image to be covered (0-1)
+        num_blobs_range (tuple): Range of number of blobs (min, max)
+
+    Returns:
+        torch.Tensor: Binary masks with shape (batch_size, 1, height, width)
+    """
+    # Initialize masks on GPU
+    masks = torch.zeros((batch_size, 1, height, width), device=device)
+
+    # Pre-compute coordinate grid on GPU
+    y_indices = torch.arange(height, device=device).view(
+        height, 1).expand(height, width)
+    x_indices = torch.arange(width, device=device).view(
+        1, width).expand(height, width)
+
+    # Prepare Gaussian kernels for smoothing
+    small_kernel = get_gaussian_kernel(7, 1.0).to(device)
+    small_kernel = small_kernel.view(1, 1, 7, 7)
+
+    large_kernel = get_gaussian_kernel(15, 2.5).to(device)
+    large_kernel = large_kernel.view(1, 1, 15, 15)
+
+    # Constants
+    max_radius = min(height, width) // 3
+    min_radius = min(height, width) // 8
+
+    # For each mask in the batch
+    for b in range(batch_size):
+        # Determine number of blobs for this mask
+        num_blobs = np.random.randint(
+            num_blobs_range[0], num_blobs_range[1] + 1)
+
+        # Target coverage for this mask
+        target_coverage = np.random.uniform(min_coverage, max_coverage)
+
+        # Initialize this mask
+        mask = torch.zeros(1, 1, height, width, device=device)
+
+        # Generate blobs with smoother edges
+        for _ in range(num_blobs):
+            # Create a low-frequency noise field first (for smooth organic shapes)
+            noise_field = torch.zeros(height, width, device=device)
+
+            # Use low-frequency sine waves to create base shape distortion
+            # This creates smoother warping compared to pure random noise
+            num_waves = np.random.randint(2, 5)
+            for i in range(num_waves):
+                freq_x = np.random.uniform(1.0, 3.0) * np.pi / width
+                freq_y = np.random.uniform(1.0, 3.0) * np.pi / height
+                phase_x = np.random.uniform(0, 2 * np.pi)
+                phase_y = np.random.uniform(0, 2 * np.pi)
+                amp = np.random.uniform(0.5, 1.0) * max_radius / (i+1.5)
+
+                # Generate smooth wave patterns
+                wave = torch.sin(x_indices * freq_x + phase_x) * \
+                    torch.sin(y_indices * freq_y + phase_y) * amp
+                noise_field += wave
+
+            # Basic ellipse parameters
+            center_y = np.random.randint(height//4, 3*height//4)
+            center_x = np.random.randint(width//4, 3*width//4)
+            radius = np.random.randint(min_radius, max_radius)
+
+            # Squeeze and stretch the ellipse with random scaling
+            scale_y = np.random.uniform(0.6, 1.4)
+            scale_x = np.random.uniform(0.6, 1.4)
+
+            # Random rotation
+            theta = np.random.uniform(0, 2 * np.pi)
+            cos_theta, sin_theta = np.cos(theta), np.sin(theta)
+
+            # Calculate elliptical distance field
+            y_scaled = (y_indices - center_y) * scale_y
+            x_scaled = (x_indices - center_x) * scale_x
+
+            # Apply rotation
+            rotated_y = y_scaled * cos_theta - x_scaled * sin_theta
+            rotated_x = y_scaled * sin_theta + x_scaled * cos_theta
+
+            # Compute distances
+            distances = torch.sqrt(rotated_y**2 + rotated_x**2)
+
+            # Apply the smooth noise field to the distance field
+            perturbed_distances = distances + noise_field
+
+            # Create base blob
+            blob = (perturbed_distances < radius).float(
+            ).unsqueeze(0).unsqueeze(0)
+
+            # Apply strong smoothing for very smooth edges
+            # Double smoothing to get really organic edges
+            blob = F.pad(blob, (7, 7, 7, 7), mode='reflect')
+            blob = F.conv2d(blob, large_kernel,
+                            padding=0)
+
+            # Apply threshold to get a nice shape
+            rand_threshold = np.random.uniform(0.3, 0.6)
+            blob = (blob > rand_threshold).float()
+
+            # Apply second smoothing pass
+            blob = F.pad(blob, (3, 3, 3, 3), mode='reflect')
+            blob = F.conv2d(blob, small_kernel, padding=0)
+            blob = (blob > 0.5).float()
+
+            # Add to mask
+            mask = torch.maximum(mask, blob)
+
+        # Measure current coverage
+        current_coverage = mask.mean().item()
+
+        # Dilate or erode if needed to better match the target coverage
+        if current_coverage > 0:  # skip empty masks
+            if current_coverage < target_coverage * 0.7:  # Too small
+                # Dilate mask to increase coverage
+                mask = F.pad(mask, (2, 2, 2, 2), mode='reflect')
+                mask = F.max_pool2d(mask, kernel_size=5, stride=1, padding=0)
+            elif current_coverage > target_coverage * 1.3:  # Too large
+                # Erode mask to decrease coverage
+                mask = F.pad(mask, (1, 1, 1, 1), mode='reflect')
+                mask = F.avg_pool2d(mask, kernel_size=3, stride=1, padding=0)
+                mask = (mask > 0.7).float()
+
+        # Final smooth and threshold
+        mask = F.pad(mask, (3, 3, 3, 3), mode='reflect')
+        mask = F.conv2d(mask, small_kernel, padding=0)
+        mask = (mask > 0.5).float()
+
+        # Add to batch
+        masks[b] = mask
+
+    return masks
+
+
+def get_gaussian_kernel(kernel_size=5, sigma=1.0):
+    """
+    Returns a 2D Gaussian kernel.
+    """
+    # Create 1D kernels
+    x = torch.linspace(-sigma * 2, sigma * 2, kernel_size)
+    x = x.view(1, -1).repeat(kernel_size, 1)
+    y = x.transpose(0, 1)
+
+    # 2D Gaussian
+    gaussian = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
+    gaussian /= gaussian.sum()
+
+    return gaussian
+
+
+def save_masks_as_images(masks, output_dir="output"):
+    """
+    Save generated masks as RGB JPG images using PIL.
+    """
+    os.makedirs(output_dir, exist_ok=True)
+
+    batch_size = masks.shape[0]
+    for i in range(batch_size):
+        # Convert mask to numpy array
+        mask = masks[i, 0].cpu().numpy()
+
+        # Scale to 0-255 range and convert to uint8
+        mask_255 = (mask * 255).astype(np.uint8)
+
+        # Create RGB image (white mask on black background)
+        rgb_mask = np.stack([mask_255, mask_255, mask_255], axis=2)
+
+        # Convert to PIL Image and save
+        img = Image.fromarray(rgb_mask)
+        img.save(os.path.join(output_dir, f"mask_{i:03d}.jpg"), quality=95)
+
+
+if __name__ == "__main__":
+    # Parameters
+    batch_size = 20
+    height = 256
+    width = 256
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    print(f"Generating {batch_size} random blob masks on {device}...")
+
+    for _ in range(5):
+        # time it
+        start = time.time()
+        masks = generate_random_mask(
+            batch_size=batch_size,
+            height=height,
+            width=width,
+            device=device,
+            min_coverage=0.2,
+            max_coverage=0.8,
+            num_blobs_range=(1, 3)
+        )
+        end = time.time()
+        # print time in milliseconds
+        print(f"Time taken: {(end - start)*1000:.2f} ms")
+
+    print("Saving masks to 'output' directory...")
+    save_masks_as_images(masks)
+
+    print("Done!")
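
Reviewer note: a minimal usage sketch of how the new helper pairs with the adapter change above. It assumes only `generate_random_mask` from this diff and that the toolkit package is importable; the dummy `latents` tensor and the channel handling mirror the `CustomAdapter` hunk and are illustrative, not part of any existing API.

import torch
from toolkit.util.mask import generate_random_mask

# stand-in for VAE latents; the (batch, channels, latent_h, latent_w) shape is assumed
device = 'cuda' if torch.cuda.is_available() else 'cpu'
latents = torch.randn(2, 4, 64, 64, device=device)

# no inpaint tensor on the batch, so build a random one and invert it,
# matching the inversion used in the CustomAdapter hunk above
inpaint_tensor = 1 - generate_random_mask(
    batch_size=latents.shape[0],
    height=latents.shape[2],
    width=latents.shape[3],
    device=latents.device,
).to(latents.device, latents.dtype)

# channel handling as in the adapter: alpha for RGBA, first channel for RGB,
# otherwise the tensor is already a single-channel mask
if inpaint_tensor.shape[1] == 4:
    inpainting_tensor_mask = inpaint_tensor[:, 3:4, :, :]
elif inpaint_tensor.shape[1] == 3:
    inpainting_tensor_mask = inpaint_tensor[:, 0:1, :, :]
else:
    inpainting_tensor_mask = inpaint_tensor

print(inpainting_tensor_mask.shape)          # torch.Size([2, 1, 64, 64])
print(inpainting_tensor_mask.mean().item())  # fraction of pixels set to 1 after inversion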