Added x and y flipping for the dataset loader

This commit is contained in:
Jaret Burkett
2023-09-17 08:42:54 -06:00
parent c698837241
commit 181f237a7b
5 changed files with 51 additions and 7 deletions

View File

@@ -411,13 +411,14 @@ class TrainESRGANProcess(BaseTrainProcess):
def run(self):
super().run()
self.load_datasets()
steps_per_step = (self.critic.num_critic_per_gen + 1)
max_step_epochs = self.max_steps // len(self.data_loader)
max_step_epochs = self.max_steps // (len(self.data_loader) // steps_per_step)
num_epochs = self.epochs
if num_epochs is None or num_epochs > max_step_epochs:
num_epochs = max_step_epochs
max_epoch_steps = len(self.data_loader) * num_epochs
max_epoch_steps = len(self.data_loader) * num_epochs * steps_per_step
num_steps = self.max_steps
if num_steps is None or num_steps > max_epoch_steps:
num_steps = max_epoch_steps
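
To see how the new steps_per_step factor plays into the epoch/step accounting, here is a small arithmetic sketch; the numbers are made up purely for illustration and are not taken from the commit.

    # hypothetical numbers, only to illustrate the accounting above
    num_critic_per_gen = 2
    steps_per_step = num_critic_per_gen + 1                           # 3 optimizer passes per generator step
    loader_len = 900                                                  # batches per epoch
    max_steps = 1000

    max_step_epochs = max_steps // (loader_len // steps_per_step)     # 1000 // 300 = 3
    max_epoch_steps = loader_len * max_step_epochs * steps_per_step   # 900 * 3 * 3 = 8100
    num_steps = min(max_steps, max_epoch_steps)                       # 1000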

View File

@@ -217,6 +217,9 @@ class DatasetConfig:
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False)
# cache latents will store them in memory
self.cache_latents: bool = kwargs.get('cache_latents', False)
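
A minimal usage sketch for the two new flags, assuming DatasetConfig is built from plain keyword arguments as the kwargs.get calls above suggest; the folder path and the exact key names besides flip_x/flip_y are illustrative, not confirmed by the commit.

    dataset_kwargs = {
        'folder_path': '/path/to/images',   # hypothetical path
        'caption_dropout_rate': 0.05,
        'flip_x': True,                     # also train on horizontally mirrored copies
        'flip_y': False,                    # vertical flips are often undesirable for photos
    }
    dataset_config = DatasetConfig(**dataset_kwargs)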

View File

@@ -1,3 +1,4 @@
import copy
import json
import os
import random
@@ -383,9 +384,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
path=file,
dataset_config=dataset_config
)
# if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
# bad_count += 1
# else:
self.file_list.append(file_item)
except Exception as e:
print(f"Error processing image: {file}")
@@ -396,6 +394,29 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
# print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
# handle x axis flips
if self.dataset_config.flip_x:
print(" - adding x axis flips")
current_file_list = [x for x in self.file_list]
for file_item in current_file_list:
# create a copy that is flipped on the x axis
new_file_item = copy.deepcopy(file_item)
new_file_item.flip_x = True
self.file_list.append(new_file_item)
# handle y axis flips
if self.dataset_config.flip_y:
print(" - adding y axis flips")
current_file_list = [x for x in self.file_list]
for file_item in current_file_list:
# create a copy that is flipped on the y axis
new_file_item = copy.deepcopy(file_item)
new_file_item.flip_y = True
self.file_list.append(new_file_item)
if self.dataset_config.flip_x or self.dataset_config.flip_y:
print(f" - Found {len(self.file_list)} images after adding flips")
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
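
The duplication above grows the file list multiplicatively: flip_x doubles it, and flip_y then doubles the already-extended list, so enabling both yields four variants per image (original, x-flipped, y-flipped, x+y-flipped). A standalone sketch of the same pattern, with copy.deepcopy so each copy carries its own flip flags; Item is a toy stand-in for FileItemDTO, not the project's class.

    import copy

    class Item:
        # toy stand-in for FileItemDTO, just to show the duplication pattern
        def __init__(self, path):
            self.path = path
            self.flip_x = False
            self.flip_y = False

    file_list = [Item('a.png'), Item('b.png')]

    for it in list(file_list):            # snapshot, then append x-flipped copies
        flipped = copy.deepcopy(it)
        flipped.flip_x = True
        file_list.append(flipped)

    for it in list(file_list):            # runs over originals *and* x-flipped copies
        flipped = copy.deepcopy(it)
        flipped.flip_y = True
        file_list.append(flipped)

    print(len(file_list))                 # 8 == 2 originals * 4 flip variants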

View File

@@ -47,6 +47,8 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
self.crop_y: int = kwargs.get('crop_y', 0)
self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False)
self.network_weight: float = self.dataset_config.network_weight
self.is_reg = self.dataset_config.is_reg

View File

@@ -251,6 +251,13 @@ class ImageProcessingDTOMixin:
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# flip horizontally; PIL's transpose returns a new image, so reassign
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# flip vertically
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# todo allow scaling and cropping, will be hard to add
# scale and crop based on file item
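
Note that PIL's Image.transpose does not modify the image in place; it returns a new Image object, which is why the result has to be reassigned. A quick standalone illustration (the file name is hypothetical):

    from PIL import Image

    img = Image.open('example.jpg')                     # hypothetical file
    mirrored = img.transpose(Image.FLIP_LEFT_RIGHT)     # horizontal mirror, new Image object
    upside_down = img.transpose(Image.FLIP_TOP_BOTTOM)  # vertical mirror, new Image object
    # `img` itself is unchanged; calling transpose without reassigning does nothing useful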
@@ -258,6 +265,7 @@ class ImageProcessingDTOMixin:
img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
else:
# Downscale the source image first
# TODO this is not right
img = img.resize(
(int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)),
Image.BICUBIC)
@@ -270,7 +278,10 @@ class ImageProcessingDTOMixin:
scale_size = self.dataset_config.resolution
else:
scale_size = random.randint(self.dataset_config.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
scaler = scale_size / min_img_size
scale_width = int((img.width + 5) * scaler)
scale_height = int((img.height + 5) * scaler)
img = img.resize((scale_width, scale_height), Image.BICUBIC)
img = transforms.RandomCrop(self.dataset_config.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
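
The replacement resize keeps the aspect ratio: both sides are scaled by scale_size / min_img_size, with a small +5 margin so the subsequent RandomCrop still fits after integer truncation. A quick worked example with made-up numbers:

    # illustrative numbers only
    width, height = 1600, 1200
    resolution = 512
    min_img_size = min(width, height)           # 1200
    scale_size = 768                            # e.g. random.randint(resolution, min_img_size)

    scaler = scale_size / min_img_size          # 0.64
    scale_width = int((width + 5) * scaler)     # int(1027.2) = 1027
    scale_height = int((height + 5) * scaler)   # int(771.2)  = 771
    # the shorter side (771) is still >= resolution (512), so RandomCrop(512) fits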
@@ -299,7 +310,7 @@ class LatentCachingFileItemDTOMixin:
self.latent_version = 1
def get_latent_info_dict(self: 'FileItemDTO'):
return OrderedDict([
item = OrderedDict([
("filename", os.path.basename(self.path)),
("scale_to_width", self.scale_to_width),
("scale_to_height", self.scale_to_height),
@@ -310,6 +321,12 @@ class LatentCachingFileItemDTOMixin:
("latent_space_version", self.latent_space_version),
("latent_version", self.latent_version),
])
# when adding new keys, append them after the fact so we don't invalidate old cached latents
if self.flip_x:
item["flip_x"] = True
if self.flip_y:
item["flip_y"] = True
return item
def get_latent_path(self: 'FileItemDTO', recalculate=False):
if self._latent_path is not None and not recalculate:
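
The conditional flip keys in get_latent_info_dict keep the serialized dict for un-flipped images identical to what it was before this commit, so whatever cache identifier is derived from it stays stable. A minimal sketch of that idea, assuming the cache key is a hash of the serialized info dict; the hashing helper below is an assumption for illustration, not the project's actual scheme.

    import hashlib
    import json
    from collections import OrderedDict

    def cache_key(info: OrderedDict) -> str:
        # hypothetical helper: derive a cache filename fragment from the info dict
        return hashlib.sha256(json.dumps(info).encode('utf-8')).hexdigest()[:16]

    base = OrderedDict([('filename', 'img.png'), ('scale_to_width', 512)])
    flipped = OrderedDict(base)
    flipped['flip_x'] = True                       # only flipped copies gain extra keys

    assert cache_key(base) != cache_key(flipped)   # flipped copies get their own latents
    # un-flipped items serialize exactly as before, so previously cached latents stay valid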