import os
import random

from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
from torch.utils.data import Dataset
from tqdm import tqdm


class ImageDataset(Dataset):
    """Dataset of the images in one directory, returned as square normalized tensors.

    Configuration keys (read from ``config``):
        name: dataset name, default ``'dataset'``.
        path: directory holding the images; required.
        scale: factor applied to both image sides before cropping; default 1.
        random_scale: randomly zoom before cropping; implies ``random_crop``.
            Default False.
        random_crop: random crop instead of a center crop; default False.
        resolution: side length of the square output; default 256.

    Only files ending in ``.jpg``, ``.jpeg``, ``.png`` or ``.webp`` (case
    insensitive) are considered.  Images whose shorter side, after ``scale``,
    is smaller than ``resolution`` are dropped at construction time.

    Each item is an RGB image passed through ``ToTensor`` and
    ``Normalize([0.5], [0.5])``, i.e. a (3, resolution, resolution) tensor with
    values in [-1, 1].
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')

    def __init__(self, config):
        """Read settings from ``config``, scan ``path`` and drop too-small images.

        Raises:
            ValueError: if ``config`` has no ``path`` key.
            AssertionError: if no image passes the size filter.
        """
        self.config = config
        self.name = self.get_config('name', 'dataset')
        self.path = self.get_config('path', required=True)
        self.scale = self.get_config('scale', 1)
        self.random_scale = self.get_config('random_scale', False)
        # we always random crop if random scale is enabled
        self.random_crop = self.random_scale or self.get_config('random_crop', False)
        self.resolution = self.get_config('resolution', 256)

        # Sorted so the item order does not depend on os.listdir's arbitrary order.
        candidates = sorted(
            os.path.join(self.path, file)
            for file in os.listdir(self.path)
            if file.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        # this might take a while
        print(" - Preprocessing image dimensions")
        self.file_list = self._filter_too_small(candidates)
        bad_count = len(candidates) - len(self.file_list)
        print(f" - Found {len(candidates)} images")
        print(f" - Found {bad_count} images that are too small")
        assert len(self.file_list) > 0, (
            f"no images with a shorter side of at least {self.resolution}px "
            f"(after scale={self.scale}) found in {self.path}"
        )

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def _filter_too_small(self, files):
        """Return the paths in ``files`` whose shorter side * scale is >= resolution."""
        kept = []
        for file in tqdm(files):
            # Image.open is lazy (header only); the with-block closes the handle.
            # EXIF rotation only swaps width/height, so min() of the raw size is
            # the same as after exif_transpose in __getitem__.
            with Image.open(file) as img:
                width, height = img.size
            if int(min(width, height) * self.scale) >= self.resolution:
                kept.append(file)
        return kept

    def get_config(self, key, default=None, required=False):
        """Return ``config[key]``, or ``default`` when absent.

        Raises:
            ValueError: if the key is absent and ``required`` is true.
        """
        if key in self.config:
            value = self.config[key]
            return value
        elif required:
            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
        else:
            return default

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _resize_shorter_side(img, size):
        """Resize ``img`` so its shorter side is ``size``, keeping the aspect ratio."""
        width, height = img.size
        if width <= height:
            new_size = (size, height * size // width)
        else:
            new_size = (width * size // height, size)
        return img.resize(new_size, Image.BICUBIC)

    def __getitem__(self, index):
        """Load, scale and crop image ``index`` and return it as a normalized tensor."""
        img_path = self.file_list[index]
        with Image.open(img_path) as raw:
            # convert() produces a new, fully loaded image, so closing the
            # source file afterwards is safe.
            img = exif_transpose(raw).convert('RGB')

        # Downscale the source image first
        img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
        min_img_size = min(img.size)

        if self.random_crop:
            if self.random_scale and min_img_size > self.resolution:
                # Random zoom: pick a shorter-side length in [resolution, min_img_size]
                # and rescale without distorting the aspect ratio.
                scale_size = random.randint(self.resolution, min_img_size)
                img = self._resize_shorter_side(img, scale_size)
            # The size filter in __init__ guarantees min side >= resolution here.
            img = transforms.RandomCrop(self.resolution)(img)
        else:
            img = transforms.CenterCrop(min_img_size)(img)
            img = img.resize((self.resolution, self.resolution), Image.BICUBIC)

        return self.transform(img)