mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
81 lines
3.0 KiB
Python
81 lines
3.0 KiB
Python
import os
|
|
import random
|
|
from PIL import Image
|
|
from PIL.ImageOps import exif_transpose
|
|
from torchvision import transforms
|
|
from torch.utils.data import Dataset
|
|
from tqdm import tqdm
|
|
|
|
|
|
class ImageDataset(Dataset):
    """Dataset that serves square, normalized image tensors from one folder.

    Recognised ``config`` keys (read through :meth:`get_config`):

    - ``path`` (required): directory that is listed for image files.
    - ``name``: dataset label, default ``'dataset'``.
    - ``scale``: factor applied to every image before cropping, default ``1``.
    - ``random_scale``: if truthy, randomly resize before a random crop,
      default ``False``. Enabling it also forces ``random_crop``.
    - ``random_crop``: random crop instead of center crop, default ``False``.
    - ``resolution``: side length of the returned square image, default ``256``.

    Only files ending in ``.jpg``, ``.jpeg``, ``.png`` or ``.webp``
    (case-insensitive) are considered. Images whose shortest side, after
    applying ``scale``, is below ``resolution`` are dropped at construction.
    """

    def __init__(self, config):
        """Read the configuration and index the usable images in ``path``.

        Raises:
            ValueError: if the required ``path`` key is missing from ``config``.
            AssertionError: if no image of sufficient size is found.
        """
        self.config = config
        self.name = self.get_config('name', 'dataset')
        self.path = self.get_config('path', required=True)
        self.scale = self.get_config('scale', 1)
        self.random_scale = self.get_config('random_scale', False)
        # we always random crop if random scale is enabled
        self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)

        self.resolution = self.get_config('resolution', 256)
        candidates = [
            os.path.join(self.path, file)
            for file in os.listdir(self.path)
            if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
        ]

        # this might take a while
        print(" - Preprocessing image dimensions")
        usable_files = []
        bad_count = 0
        for file in tqdm(candidates):
            # Use a context manager so the handle is released right away
            # instead of accumulating one open file per scanned image.
            with Image.open(file) as img:
                shortest_side = min(img.size)
            if int(shortest_side * self.scale) >= self.resolution:
                usable_files.append(file)
            else:
                bad_count += 1

        # Keep only the images that passed the size check; without this
        # assignment the filtering above would have no effect.
        self.file_list = usable_files

        print(f" - Found {len(self.file_list)} images")
        print(f" - Found {bad_count} images that are too small")
        assert len(self.file_list) > 0, f"no images found in {self.path}"

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def get_config(self, key, default=None, required=False):
        """Return ``self.config[key]``, falling back to ``default``.

        Args:
            key: name of the configuration entry.
            default: value returned when ``key`` is absent and not required.
            required: when true, a missing ``key`` is an error.

        Raises:
            ValueError: if ``key`` is missing and ``required`` is true.
        """
        if key in self.config:
            return self.config[key]
        if required:
            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
        return default

    def __len__(self):
        """Return the number of images kept after size filtering."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Load, scale, crop and transform the image at ``index``.

        The image is EXIF-transposed, converted to RGB and resized by
        ``self.scale``. It is then either randomly cropped (optionally after
        a random resize) or center-cropped to a square and resized, so the
        result is ``self.resolution`` pixels on each side before
        ``self.transform`` is applied.

        Returns:
            The output of ``self.transform`` for the processed image.
        """
        img_path = self.file_list[index]
        # convert() yields an independent image, so the source file can be
        # closed as soon as the with-block ends.
        with Image.open(img_path) as source:
            img = exif_transpose(source).convert('RGB')

        # Downscale the source image first
        img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
        min_img_size = min(img.size)

        if self.random_crop:
            if self.random_scale and min_img_size > self.resolution:
                # Pick a random side length between the target resolution and
                # the current shortest side, so the crop below always fits.
                scale_size = random.randint(self.resolution, int(min_img_size))
                # NOTE(review): this resizes to a square regardless of the
                # source aspect ratio - confirm the distortion is intended.
                img = img.resize((scale_size, scale_size), Image.BICUBIC)
            img = transforms.RandomCrop(self.resolution)(img)
        else:
            img = transforms.CenterCrop(min_img_size)(img)
            img = img.resize((self.resolution, self.resolution), Image.BICUBIC)

        img = self.transform(img)

        return img
|