From 181f237a7bc3d7e555f079d2000c39bd06495ab6 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 17 Sep 2023 08:42:54 -0600
Subject: [PATCH] added flipping x and y for dataset loader

---
 jobs/process/TrainESRGANProcess.py          |  5 ++--
 toolkit/config_modules.py                   |  2 ++
 toolkit/data_loader.py                      | 27 ++++++++++++++++++---
 toolkit/data_transfer_object/data_loader.py |  2 ++
 toolkit/dataloader_mixins.py                | 21 ++++++++++++++--
 5 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/jobs/process/TrainESRGANProcess.py b/jobs/process/TrainESRGANProcess.py
index 599eed17..c00d92bd 100644
--- a/jobs/process/TrainESRGANProcess.py
+++ b/jobs/process/TrainESRGANProcess.py
@@ -411,13 +411,14 @@ class TrainESRGANProcess(BaseTrainProcess):
     def run(self):
         super().run()
         self.load_datasets()
+        steps_per_step = (self.critic.num_critic_per_gen + 1)
 
-        max_step_epochs = self.max_steps // len(self.data_loader)
+        max_step_epochs = self.max_steps // (len(self.data_loader) // steps_per_step)
         num_epochs = self.epochs
         if num_epochs is None or num_epochs > max_step_epochs:
             num_epochs = max_step_epochs
 
-        max_epoch_steps = len(self.data_loader) * num_epochs
+        max_epoch_steps = len(self.data_loader) * num_epochs * steps_per_step
         num_steps = self.max_steps
         if num_steps is None or num_steps > max_epoch_steps:
             num_steps = max_epoch_steps
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index bf7d286d..4c963267 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -217,6 +217,8 @@ class DatasetConfig:
         self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
         self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
         self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
+        self.flip_x: bool = kwargs.get('flip_x', False)
+        self.flip_y: bool = kwargs.get('flip_y', False)
 
         # cache latents will store them in memory
         self.cache_latents: bool = kwargs.get('cache_latents', False)
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index bf3ffb9a..86214b1a 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -1,3 +1,4 @@
+import copy
 import json
 import os
 import random
@@ -383,9 +384,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
                         path=file,
                         dataset_config=dataset_config
                     )
-                    # if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
-                    #     bad_count += 1
-                    # else:
                     self.file_list.append(file_item)
                 except Exception as e:
                     print(f"Error processing image: {file}")
@@ -396,6 +394,29 @@
         # print(f" - Found {bad_count} images that are too small")
         assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
 
+        # handle x axis flips
+        if self.dataset_config.flip_x:
+            print(" - adding x axis flips")
+            current_file_list = [x for x in self.file_list]
+            for file_item in current_file_list:
+                # create a copy that is flipped on the x axis
+                new_file_item = copy.deepcopy(file_item)
+                new_file_item.flip_x = True
+                self.file_list.append(new_file_item)
+
+        # handle y axis flips
+        if self.dataset_config.flip_y:
+            print(" - adding y axis flips")
+            current_file_list = [x for x in self.file_list]
+            for file_item in current_file_list:
+                # create a copy that is flipped on the y axis
+                new_file_item = copy.deepcopy(file_item)
+                new_file_item.flip_y = True
+                self.file_list.append(new_file_item)
+
+        if self.dataset_config.flip_x or self.dataset_config.flip_y:
+            print(f" - Found {len(self.file_list)} images after adding flips")
+
         self.transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 59edf2c9..791acb3c 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -47,6 +47,8 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
         self.crop_y: int = kwargs.get('crop_y', 0)
         self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
         self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
+        self.flip_x: bool = kwargs.get('flip_x', False)
+        self.flip_y: bool = kwargs.get('flip_y', False)
         self.network_weight: float = self.dataset_config.network_weight
         self.is_reg = self.dataset_config.is_reg
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 536c8498..ba9c95d0 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -251,6 +251,13 @@ class ImageProcessingDTOMixin:
                 raise ValueError(
                     f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
 
+        if self.flip_x:
+            # flip horizontally; transpose returns a new image
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+        if self.flip_y:
+            # flip vertically; transpose returns a new image
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
         if self.dataset_config.buckets:
             # todo allow scaling and cropping, will be hard to add
             # scale and crop based on file item
@@ -258,6 +265,7 @@
             img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
         else:
             # Downscale the source image first
+            # TODO this is not right
             img = img.resize(
                 (int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)),
                 Image.BICUBIC)
@@ -270,7 +278,10 @@
                     scale_size = self.dataset_config.resolution
                 else:
                     scale_size = random.randint(self.dataset_config.resolution, int(min_img_size))
-                img = img.resize((scale_size, scale_size), Image.BICUBIC)
+                scaler = scale_size / min_img_size
+                scale_width = int((img.width + 5) * scaler)
+                scale_height = int((img.height + 5) * scaler)
+                img = img.resize((scale_width, scale_height), Image.BICUBIC)
                 img = transforms.RandomCrop(self.dataset_config.resolution)(img)
             else:
                 img = transforms.CenterCrop(min_img_size)(img)
@@ -299,7 +310,7 @@ class LatentCachingFileItemDTOMixin:
         self.latent_version = 1
 
     def get_latent_info_dict(self: 'FileItemDTO'):
-        return OrderedDict([
+        item = OrderedDict([
             ("filename", os.path.basename(self.path)),
             ("scale_to_width", self.scale_to_width),
             ("scale_to_height", self.scale_to_height),
@@ -310,6 +321,12 @@
             ("latent_space_version", self.latent_space_version),
             ("latent_version", self.latent_version),
         ])
+        # add the new keys after the originals so we don't invalidate old cached latents
+        if self.flip_x:
+            item["flip_x"] = True
+        if self.flip_y:
+            item["flip_y"] = True
+        return item
 
     def get_latent_path(self: 'FileItemDTO', recalculate=False):
         if self._latent_path is not None and not recalculate:
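
Notes (usage sketches, not part of the diff above):

The new flags are plain DatasetConfig kwargs. A minimal sketch of how they are
exercised, assuming the usual construction path; `folder_path` and `resolution`
are illustrative here, while `flip_x`, `flip_y`, and their False defaults come
from this patch:

    from toolkit.config_modules import DatasetConfig

    dataset_config = DatasetConfig(
        folder_path='/path/to/images',  # illustrative
        resolution=512,                 # illustrative
        flip_x=True,   # append a horizontally mirrored copy of every image
        flip_y=False,  # vertical flips are rarely useful for natural images
    )

Each enabled axis deep-copies the file list built so far, so enabling one flip
doubles the dataset and enabling both quadruples it (the y pass also mirrors
the x-flipped copies).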
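The transform itself relies on PIL's Image.transpose, which returns a new
image rather than flipping in place (hence the `img = img.transpose(...)`
assignments in ImageProcessingDTOMixin). A standalone check of the two modes:

    from PIL import Image

    img = Image.new('RGB', (4, 2))
    img.putpixel((0, 0), (255, 0, 0))                # mark the top-left pixel red

    flipped = img.transpose(Image.FLIP_LEFT_RIGHT)   # flip_x: horizontal mirror
    assert flipped.getpixel((3, 0)) == (255, 0, 0)   # red is now top-right
    assert img.getpixel((0, 0)) == (255, 0, 0)       # original is untouched

    upside_down = img.transpose(Image.FLIP_TOP_BOTTOM)  # flip_y: vertical mirror
    assert upside_down.getpixel((0, 1)) == (255, 0, 0)  # red is now bottom-left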
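On caching: per the comment in the get_latent_info_dict hunk, the flip keys
are appended only when set, so an unflipped item produces exactly the dict it
produced before this patch and keeps its cache entry, while flipped copies
hash to new ones. A toy sketch of that property; `cache_key` is a hypothetical
stand-in for the real latent-path hashing, which is outside this diff:

    import hashlib
    import json
    from collections import OrderedDict

    def cache_key(info: OrderedDict) -> str:
        # hypothetical stand-in for the real latent-path hashing
        return hashlib.sha256(json.dumps(info).encode()).hexdigest()[:16]

    base = OrderedDict([('filename', 'img.jpg'), ('latent_version', 1)])
    flipped = OrderedDict(base)
    flipped['flip_x'] = True  # only flipped copies gain the extra key

    assert cache_key(base) == cache_key(OrderedDict(base))  # old items keep their key
    assert cache_key(base) != cache_key(flipped)            # flipped copies get a new one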