mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-27 00:49:47 +00:00
added flipping x and y for dataset loader

@@ -411,13 +411,14 @@ class TrainESRGANProcess(BaseTrainProcess):
     def run(self):
         super().run()
         self.load_datasets()
+        steps_per_step = (self.critic.num_critic_per_gen + 1)

-        max_step_epochs = self.max_steps // len(self.data_loader)
+        max_step_epochs = self.max_steps // (len(self.data_loader) // steps_per_step)
         num_epochs = self.epochs
         if num_epochs is None or num_epochs > max_step_epochs:
             num_epochs = max_step_epochs

-        max_epoch_steps = len(self.data_loader) * num_epochs
+        max_epoch_steps = len(self.data_loader) * num_epochs * steps_per_step
         num_steps = self.max_steps
         if num_steps is None or num_steps > max_epoch_steps:
             num_steps = max_epoch_steps
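
For context, here is the same arithmetic as a minimal sketch with made-up numbers; the real loader size, critic setting, and step cap are not shown in this diff.

# Hypothetical numbers: 900 batches per epoch, 2 critic updates per
# generator update, and a configured cap of 10_000 steps.
len_data_loader = 900
num_critic_per_gen = 2
max_steps = 10_000

steps_per_step = num_critic_per_gen + 1  # 3: each batch now consumes 3 optimizer steps
max_step_epochs = max_steps // (len_data_loader // steps_per_step)  # 10_000 // 300 == 33
num_epochs = max_step_epochs  # assume no explicit epoch limit was configured
max_epoch_steps = len_data_loader * num_epochs * steps_per_step  # 900 * 33 * 3 == 89_100
num_steps = min(max_steps, max_epoch_steps)  # the 10_000-step cap stays the binding limit
print(steps_per_step, max_step_epochs, max_epoch_steps, num_steps)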

@@ -217,6 +217,9 @@ class DatasetConfig:
         self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
         self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
         self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
         self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
+        self.flip_x: bool = kwargs.get('flip_x', False)
+        self.flip_y: bool = kwargs.get('flip_y', False)

         # cache latents will store them in memory
         self.cache_latents: bool = kwargs.get('cache_latents', False)
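
For illustration, a sketch of how these options might be passed in. Only keyword names visible in this hunk are used, and whether DatasetConfig needs other arguments (a dataset path, for example) is not shown here.

# Hypothetical dataset options; only keys visible in the hunk above are used.
dataset_kwargs = {
    'token_dropout_rate': 0.1,
    'shuffle_tokens': True,
    'caption_dropout_rate': 0.05,
    'flip_x': True,   # also train on horizontally mirrored copies
    'flip_y': False,  # skip vertical mirroring (usually unwanted for photos)
    'cache_latents': True,
}

# Mirrors the parsing pattern above: missing keys fall back to their defaults.
flip_x = bool(dataset_kwargs.get('flip_x', False))
flip_y = bool(dataset_kwargs.get('flip_y', False))
print(flip_x, flip_y)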

@@ -1,3 +1,4 @@
+import copy
 import json
 import os
 import random

@@ -383,9 +384,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
                     path=file,
                     dataset_config=dataset_config
                 )
-                # if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
-                #     bad_count += 1
-                # else:
                 self.file_list.append(file_item)
             except Exception as e:
                 print(f"Error processing image: {file}")

@@ -396,6 +394,29 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         # print(f" - Found {bad_count} images that are too small")
         assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"

+        # handle x axis flips
+        if self.dataset_config.flip_x:
+            print(" - adding x axis flips")
+            current_file_list = [x for x in self.file_list]
+            for file_item in current_file_list:
+                # create a copy that is flipped on the x axis
+                new_file_item = copy.deepcopy(file_item)
+                new_file_item.flip_x = True
+                self.file_list.append(new_file_item)
+
+        # handle y axis flips
+        if self.dataset_config.flip_y:
+            print(" - adding y axis flips")
+            current_file_list = [x for x in self.file_list]
+            for file_item in current_file_list:
+                # create a copy that is flipped on the y axis
+                new_file_item = copy.deepcopy(file_item)
+                new_file_item.flip_y = True
+                self.file_list.append(new_file_item)
+
+        if self.dataset_config.flip_x or self.dataset_config.flip_y:
+            print(f" - Found {len(self.file_list)} images after adding flips")
+
         self.transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
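
Worth noting, and only inferred from the code above: the y-axis pass iterates over a list that already contains the x-flipped copies, so enabling both flags roughly quadruples the file list, and the extra copies carry both flags. A minimal counting sketch with a stand-in item class (not the real FileItemDTO):

import copy
from dataclasses import dataclass

@dataclass
class Item:  # stand-in for FileItemDTO, just for counting
    path: str
    flip_x: bool = False
    flip_y: bool = False

file_list = [Item("a.jpg"), Item("b.jpg")]

# x pass: copy every existing item with flip_x set
for file_item in list(file_list):
    new_item = copy.deepcopy(file_item)
    new_item.flip_x = True
    file_list.append(new_item)

# y pass: also copies the x-flipped items, producing flip_x + flip_y combinations
for file_item in list(file_list):
    new_item = copy.deepcopy(file_item)
    new_item.flip_y = True
    file_list.append(new_item)

print(len(file_list))  # 8 == 2 * 4: original, x-flipped, y-flipped, and x+y-flipped variants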

@@ -47,6 +47,8 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag
         self.crop_y: int = kwargs.get('crop_y', 0)
         self.crop_width: int = kwargs.get('crop_width', self.scale_to_width)
         self.crop_height: int = kwargs.get('crop_height', self.scale_to_height)
+        self.flip_x: bool = kwargs.get('flip_x', False)
+        self.flip_y: bool = kwargs.get('flip_y', False)

         self.network_weight: float = self.dataset_config.network_weight
         self.is_reg = self.dataset_config.is_reg

@@ -251,6 +251,13 @@ class ImageProcessingDTOMixin:
             raise ValueError(
                 f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

+        if self.flip_x:
+            # do a flip; transpose returns a new image, so reassign it
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+        if self.flip_y:
+            # do a flip; transpose returns a new image, so reassign it
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
         if self.dataset_config.buckets:
             # todo allow scaling and cropping, will be hard to add
             # scale and crop based on file item
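
A quick standalone check of the flip calls themselves; Pillow's Image.transpose returns a new image rather than modifying the original, which is why the result has to be assigned back:

from PIL import Image

# Tiny 2x1 image: a red pixel on the left, a blue pixel on the right.
img = Image.new("RGB", (2, 1))
img.putpixel((0, 0), (255, 0, 0))
img.putpixel((1, 0), (0, 0, 255))

flipped = img.transpose(Image.FLIP_LEFT_RIGHT)  # mirrored copy
print(img.getpixel((0, 0)))      # (255, 0, 0) -- original is untouched
print(flipped.getpixel((0, 0)))  # (0, 0, 255) -- copy is mirrored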

@@ -258,6 +265,7 @@ class ImageProcessingDTOMixin:
             img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
         else:
             # Downscale the source image first
+            # TODO this is not right
             img = img.resize(
                 (int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)),
                 Image.BICUBIC)

@@ -270,7 +278,10 @@ class ImageProcessingDTOMixin:
                 scale_size = self.dataset_config.resolution
             else:
                 scale_size = random.randint(self.dataset_config.resolution, int(min_img_size))
-            img = img.resize((scale_size, scale_size), Image.BICUBIC)
+            scaler = scale_size / min_img_size
+            scale_width = int((img.width + 5) * scaler)
+            scale_height = int((img.height + 5) * scaler)
+            img = img.resize((scale_width, scale_height), Image.BICUBIC)
             img = transforms.RandomCrop(self.dataset_config.resolution)(img)
         else:
             img = transforms.CenterCrop(min_img_size)(img)
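
To see what the new sizing math produces, here is the same arithmetic with made-up numbers (a 1200x800 source and a 512 training resolution; none of these values come from the diff):

import random

resolution = 512
img_width, img_height = 1200, 800
min_img_size = min(img_width, img_height)  # 800

scale_size = random.randint(resolution, int(min_img_size))  # somewhere in [512, 800]
scaler = scale_size / min_img_size
# the +5 appears to pad the result so rounding can't leave the short side
# smaller than the crop size; that reading is an inference, not stated in the diff
scale_width = int((img_width + 5) * scaler)
scale_height = int((img_height + 5) * scaler)
print(scale_size, scale_width, scale_height)
# transforms.RandomCrop(resolution) then takes a 512x512 crop from the resized image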

@@ -299,7 +310,7 @@ class LatentCachingFileItemDTOMixin:
         self.latent_version = 1

     def get_latent_info_dict(self: 'FileItemDTO'):
-        return OrderedDict([
+        item = OrderedDict([
             ("filename", os.path.basename(self.path)),
             ("scale_to_width", self.scale_to_width),
             ("scale_to_height", self.scale_to_height),
@@ -310,6 +321,12 @@ class LatentCachingFileItemDTOMixin:
             ("latent_space_version", self.latent_space_version),
             ("latent_version", self.latent_version),
         ])
+        # when adding items, do it after so we don't change old latents
+        if self.flip_x:
+            item["flip_x"] = True
+        if self.flip_y:
+            item["flip_y"] = True
+        return item

     def get_latent_path(self: 'FileItemDTO', recalculate=False):
         if self._latent_path is not None and not recalculate:
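
The effect on cached latents can be sketched as follows; the field values and the hashing step are placeholders, since the real get_latent_path logic is not shown in this diff. Flipped copies gain extra keys and therefore distinct cache entries, while unflipped items produce exactly the dict they did before this commit.

import hashlib
import json
from collections import OrderedDict

# Placeholder values standing in for a real FileItemDTO.
item = OrderedDict([
    ("filename", "example.jpg"),
    ("scale_to_width", 512),
    ("scale_to_height", 512),
    ("latent_space_version", "sdxl"),  # hypothetical version string
    ("latent_version", 1),
])

flip_x, flip_y = True, False
if flip_x:
    item["flip_x"] = True
if flip_y:
    item["flip_y"] = True

# One plausible way such a dict could be turned into a cache key;
# the actual get_latent_path implementation is not shown in this diff.
key = hashlib.sha256(json.dumps(item).encode("utf-8")).hexdigest()[:16]
print(key)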