# ai-toolkit/testing/test_bucket_dataloader.py
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import sys
import os
import cv2
import random
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.paths import SD_SCRIPTS_ROOT
import torchvision.transforms.functional
from toolkit.image_utils import show_img
sys.path.append(SD_SCRIPTS_ROOT)
from library.model_util import load_vae
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
    trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig
import argparse
from tqdm import tqdm
parser = argparse.ArgumentParser()
# nargs='?' so the default is actually used when no folder is passed
parser.add_argument('dataset_folder', type=str, nargs='?', default='input')
parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args()
dataset_folder = args.dataset_folder
resolution = 512
bucket_tolerance = 64
batch_size = 1
##
dataset_config = DatasetConfig(
    dataset_path=dataset_folder,
    control_path=dataset_folder,
    resolution=resolution,
    # caption_ext='json',
    default_caption='default',
    # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
    # poi='person',
    # shuffle_augmentations=True,
    # augmentations=[
    #     {
    #         'method': 'Posterize',
    #         'num_bits': [(0, 4), (0, 4), (0, 4)],
    #         'p': 1.0
    #     },
    # ]
)
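# A rough sketch of the bucket math assumed above (illustrative only, not
# necessarily the library's exact algorithm): a bucket keeps roughly
# resolution**2 pixels and snaps each side down to a multiple of
# bucket_tolerance, instead of center-cropping everything to a square.
# def illustrative_bucket(w, h, res=512, tol=64):
#     scale = (res * res / (w * h)) ** 0.5
#     return (int(w * scale) // tol * tol, int(h * scale) // tol * tol)
# illustrative_bucket(1920, 1080)  # -> (640, 384)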
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
def random_blur(img, min_kernel_size=3, max_kernel_size=23, p=0.5):
    if random.random() < p:
        kernel_size = random.randint(min_kernel_size, max_kernel_size)
        # make sure it is odd
        if kernel_size % 2 == 0:
            kernel_size += 1
        img = torchvision.transforms.functional.gaussian_blur(img, kernel_size=kernel_size)
    return img
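# Usage sketch (hypothetical values): with p=1.0 the blur always fires; the
# kernel size is forced to be odd because gaussian_blur requires odd kernels.
# blurred = random_blur(torch.rand(1, 3, 64, 64), min_kernel_size=3, max_kernel_size=9, p=1.0)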
def quantize(image, palette):
    """
    Similar to PIL.Image.quantize() in PyTorch. Built to maintain gradient.
    Only works for one image, i.e. CHW. Does NOT work for batches.
    ref https://discuss.pytorch.org/t/color-quantization/104528/4
    """
    orig_dtype = image.dtype
    C, H, W = image.shape
    n_colors = palette.shape[0]
    # Easier to work with a flat list of pixels
    flat_img = image.view(C, -1).T  # [C, H, W] -> [H*W, C]
    # Repeat the image so that there are n_colors columns of the same image
    flat_img_per_color = flat_img.unsqueeze(1).expand(-1, n_colors, -1)  # [H*W, C] -> [H*W, n_colors, C]
    # Get the euclidean distance between each pixel in each column and the column's respective color,
    # i.e. column 1 lists the distance of each pixel to color #1 in the palette, column 2 to color #2, etc.
    squared_distance = (flat_img_per_color - palette.unsqueeze(0)) ** 2
    euclidean_distance = torch.sqrt(torch.sum(squared_distance, dim=-1) + 1e-8)  # [H*W, n_colors, C] -> [H*W, n_colors]
    # Get the shortest distance (one value per row (H*W) is selected)
    min_distances, min_indices = torch.min(euclidean_distance, dim=-1)  # [H*W, n_colors] -> [H*W]
    # Create a one-hot mask for the closest colors
    one_hot_mask = torch.nn.functional.one_hot(min_indices, num_classes=n_colors).float()  # [H*W, n_colors]
    # Multiply the mask with the palette colors to get the quantized image
    quantized = torch.matmul(one_hot_mask, palette)  # [H*W, n_colors] @ [n_colors, C] -> [H*W, C]
    # Reshape it back to the original input format
    quantized_img = quantized.T.view(C, H, W)  # [H*W, C] -> [C, H, W]
    return quantized_img.to(orig_dtype)
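# Usage sketch (hypothetical values): snap every pixel of a CHW image to the
# nearest entry of a two-color palette; the output is pure black and white.
# bw_palette = torch.tensor([[0., 0., 0.], [255., 255., 255.]])
# quantized = quantize(torch.rand(3, 8, 8) * 255, bw_palette)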
def color_block_imgs(img, neg1_1=False):
    # expects values 0 - 1 (or -1 to 1 when neg1_1 is True)
    orig_dtype = img.dtype
    if neg1_1:
        img = img * 0.5 + 0.5
    img = img * 255
    img = img.clamp(0, 255)
    img = img.to(torch.uint8)
    img_chunks = torch.chunk(img, img.shape[0], dim=0)
    posterized_chunks = []
    for chunk in img_chunks:
        img_size = (chunk.shape[2] + chunk.shape[3]) // 2
        # min kernel size of 1% of the image, max 10%
        min_kernel_size = int(img_size * 0.01)
        max_kernel_size = int(img_size * 0.1)
        # blur first
        chunk = random_blur(chunk, min_kernel_size=min_kernel_size, max_kernel_size=max_kernel_size, p=0.8)
        num_colors = random.randint(1, 16)
        resize_to = 16
        # chunk = torchvision.transforms.functional.posterize(chunk, num_bits_to_use)
        # mean_color = [int(x.item()) for x in torch.mean(chunk.float(), dim=(0, 2, 3))]
        # shrink the image down to resize_to x resize_to to sample candidate colors
        shrunk = torchvision.transforms.functional.resize(chunk, [resize_to, resize_to])
        mean_color = [int(x.item()) for x in torch.mean(shrunk.float(), dim=(0, 2, 3))]
        colors = shrunk.view(3, -1).T
        # remove duplicates
        colors = torch.unique(colors, dim=0)
        colors = colors.numpy().tolist()
        use_colors = [random.choice(colors) for _ in range(num_colors)]
        palette = torch.tensor([
            [0, 0, 0],
            mean_color,
            [255, 255, 255],
        ] + use_colors, dtype=torch.float32)
        chunk = quantize(chunk.squeeze(0), palette).unsqueeze(0)
        # chunk = torchvision.transforms.functional.equalize(chunk)
        # color jitter
        if random.random() < 0.5:
            chunk = torchvision.transforms.functional.adjust_contrast(chunk, random.uniform(1.0, 1.5))
        if random.random() < 0.5:
            chunk = torchvision.transforms.functional.adjust_saturation(chunk, random.uniform(1.0, 2.0))
        # if random.random() < 0.5:
        #     chunk = torchvision.transforms.functional.adjust_brightness(chunk, random.uniform(0.5, 1.5))
        chunk = random_blur(chunk, p=0.6)
        posterized_chunks.append(chunk)
    img = torch.cat(posterized_chunks, dim=0)
    img = img.to(orig_dtype)
    img = img / 255
    if neg1_1:
        img = img * 2 - 1
    return img
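# Usage sketch (hypothetical values): the function expects an NCHW batch; with
# neg1_1=True the input/output range is [-1, 1], matching batch.tensor below.
# blocks = color_block_imgs(torch.rand(2, 3, 256, 256) * 2 - 1, neg1_1=True)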
# run through the epochs and check sizes
dataloader_iterator = iter(dataloader)
for epoch in range(args.epochs):
    for batch in tqdm(dataloader):
        batch: 'DataLoaderBatchDTO'
        img_batch = batch.tensor
        # img_batch = color_block_imgs(img_batch, neg1_1=True)
        chunks = torch.chunk(img_batch, batch_size, dim=0)
        # put them side by side
        big_img = torch.cat(chunks, dim=3)
        big_img = big_img.squeeze(0)
        control_chunks = torch.chunk(batch.control_tensor, batch_size, dim=0)
        big_control_img = torch.cat(control_chunks, dim=3)
        # control tensor is 0-1; rescale to -1..1 to match the image tensor
        big_control_img = big_control_img.squeeze(0) * 2 - 1
        big_img = torch.cat([big_img, big_control_img], dim=2)
        # value range, useful when debugging
        min_val = big_img.min()
        max_val = big_img.max()
        big_img = (big_img / 2 + 0.5).clamp(0, 1)
        # convert to a PIL image
        img = transforms.ToPILImage()(big_img)
        # show_img(img)
        # time.sleep(1.0)
    # if not the last epoch, trigger a fresh bucket/epoch setup
    if epoch < args.epochs - 1:
        trigger_dataloader_setup_epoch(dataloader)
cv2.destroyAllWindows()
print('done')