mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-27 00:49:47 +00:00
206 lines
7.6 KiB
Python
206 lines
7.6 KiB
Python
import os
|
|
import random
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
from PIL.ImageOps import exif_transpose
|
|
from torchvision import transforms
|
|
from torch.utils.data import Dataset
|
|
from tqdm import tqdm
|
|
import albumentations as A
|
|
|
|
|
|
class ImageDataset(Dataset):
    """Folder-backed dataset of single images.

    Scans ``path`` for jpg/jpeg/png/webp files, drops any image whose
    smaller side (after applying ``scale``) is below ``resolution``, and
    returns each item as a tensor normalized to [-1, 1].
    """

    def __init__(self, config):
        self.config = config
        self.name = self.get_config('name', 'dataset')
        self.path = self.get_config('path', required=True)
        self.scale = self.get_config('scale', 1)
        self.random_scale = self.get_config('random_scale', False)
        # we always random crop if random scale is enabled
        self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)

        self.resolution = self.get_config('resolution', 256)
        self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
                          file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]

        # this might take a while
        print(f" - Preprocessing image dimensions")
        new_file_list = []
        bad_count = 0
        for file in tqdm(self.file_list):
            # Image.open is lazy and holds the file handle open; use a
            # context manager so we don't leak descriptors on large folders
            with Image.open(file) as img:
                if int(min(img.size) * self.scale) >= self.resolution:
                    new_file_list.append(file)
                else:
                    bad_count += 1

        print(f" - Found {len(self.file_list)} images")
        print(f" - Found {bad_count} images that are too small")
        # BUGFIX: the filtered list was built but never used, so too-small
        # images stayed in the dataset and would later make RandomCrop fail
        # in __getitem__. Keep only images that passed the size check.
        self.file_list = new_file_list
        assert len(self.file_list) > 0, f"no images found in {self.path}"

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
        ])

    def get_config(self, key, default=None, required=False):
        """Read ``key`` from the config dict.

        Returns ``default`` when the key is absent, unless ``required`` is
        True, in which case a ValueError is raised.
        """
        if key in self.config:
            return self.config[key]
        elif required:
            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
        else:
            return default

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        img_path = self.file_list[index]
        # honor EXIF orientation and force 3-channel RGB
        img = exif_transpose(Image.open(img_path)).convert('RGB')

        # Downscale the source image first
        img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
        min_img_size = min(img.size)

        if self.random_crop:
            if self.random_scale and min_img_size > self.resolution:
                if min_img_size < self.resolution:
                    # defensive branch; unreachable given the check above,
                    # kept as a loud canary in case the logic changes
                    print(
                        f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
                    scale_size = self.resolution
                else:
                    scale_size = random.randint(self.resolution, int(min_img_size))
                # NOTE(review): resizing to a square (scale_size, scale_size)
                # distorts the aspect ratio before cropping — preserved from
                # the original behavior; confirm this is intentional
                img = img.resize((scale_size, scale_size), Image.BICUBIC)
            img = transforms.RandomCrop(self.resolution)(img)
        else:
            img = transforms.CenterCrop(min_img_size)(img)
            img = img.resize((self.resolution, self.resolution), Image.BICUBIC)

        img = self.transform(img)

        return img
|
|
|
|
|
|
class Augments:
    """Parsed augmentation spec: an albumentations method name plus kwargs.

    String params of the form ``"cv2.NAME"`` are resolved to the actual
    cv2 constant value so configs can reference enums by name.
    """

    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        self.params = kwargs.get('params', {})

        # convert "cv2.X" string params into the cv2 enum values
        for key, value in self.params.items():
            if isinstance(value, str):
                # split the string
                split_string = value.split('.')
                if len(split_string) == 2 and split_string[0] == 'cv2':
                    enum_name = split_string[1]
                    # BUGFIX: the original checked hasattr() on the raw name
                    # but fetched name.upper(), so valid lower-case specs
                    # like "cv2.inter_cubic" were rejected. Resolve the raw
                    # name first, then fall back to the upper-cased form.
                    if hasattr(cv2, enum_name):
                        self.params[key] = getattr(cv2, enum_name)
                    elif hasattr(cv2, enum_name.upper()):
                        self.params[key] = getattr(cv2, enum_name.upper())
                    else:
                        raise ValueError(f"invalid cv2 enum: {split_string[1]}")
|
|
|
|
|
|
class AugmentedImageDataset(ImageDataset):
    """ImageDataset variant that also yields an augmented copy.

    Builds an albumentations pipeline from the ``augmentations`` config
    list; ``__getitem__`` returns ``(original, augmented)`` as 0-1 tensors.
    """

    def __init__(self, config):
        super().__init__(config)
        self.augmentations = self.get_config('augmentations', [])
        self.augmentations = [Augments(**aug) for aug in self.augmentations]

        pipeline = []
        for aug in self.augmentations:
            # make sure method name is valid
            assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
            # look up the albumentations class and instantiate with params
            pipeline.append(getattr(A, aug.method_name)(**aug.params))

        self.aug_transform = A.Compose(pipeline)
        self.original_transform = self.transform
        # replace transform so we get raw pil image
        self.transform = transforms.Compose([])

    def __getitem__(self, index):
        # parent returns a raw PIL image because self.transform is a no-op
        pil_image = super().__getitem__(index)

        # Convert RGB to BGR (cv2 channel order) for the augmentation pass
        bgr_array = np.array(pil_image)[:, :, ::-1].copy()

        # apply augmentations
        augmented_bgr = self.aug_transform(image=bgr_array)["image"]

        # back to an RGB PIL image
        augmented = Image.fromarray(cv2.cvtColor(augmented_bgr, cv2.COLOR_BGR2RGB))

        # return both images as 0 - 1 tensors
        to_tensor = transforms.ToTensor()
        return to_tensor(pil_image), to_tensor(augmented)
|
|
|
|
|
|
class PairedImageDataset(Dataset):
    """Dataset yielding ``(image, prompt, network_weight)`` triples.

    Each image may have a sidecar ``<name>.txt`` caption; when absent,
    ``default_prompt`` is used. Images are resized to height ``size``
    with aspect ratio preserved, then normalized to [-1, 1].
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.size = self.get_config('size', 512)
        self.path = self.get_config('path', required=True)
        self.default_prompt = self.get_config('default_prompt', '')
        self.network_weight = self.get_config('network_weight', 1.0)
        self.file_list = [
            os.path.join(self.path, name)
            for name in os.listdir(self.path)
            if name.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
        ]
        print(f" - Found {len(self.file_list)} images")

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
        ])

    def __len__(self):
        return len(self.file_list)

    def get_config(self, key, default=None, required=False):
        """Look up ``key`` in the config dict.

        Raises ValueError when the key is missing and ``required`` is True;
        otherwise falls back to ``default``.
        """
        if key not in self.config:
            if required:
                raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
            return default
        return self.config[key]

    def __getitem__(self, index):
        img_path = self.file_list[index]
        # honor EXIF orientation and force 3-channel RGB
        img = exif_transpose(Image.open(img_path)).convert('RGB')

        # see if a sidecar prompt file exists next to the image
        prompt_path = os.path.splitext(img_path)[0] + '.txt'
        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                raw = f.read()
            # turn newlines (any OS) into comma separators
            raw = raw.replace('\n', ', ').replace('\r', ', ')
            # drop empty fragments and rejoin cleanly
            fragments = [piece.strip() for piece in raw.split(',') if piece.strip()]
            prompt = ', '.join(fragments)
        else:
            prompt = self.default_prompt

        # resize to a fixed height, keeping the aspect ratio
        height = self.size
        width = int(img.size[0] * height / img.size[1])
        img = img.resize((width, height), Image.BICUBIC)
        img = self.transform(img)

        return img, prompt, self.network_weight
|
|
|