diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index e77db6a0..e0e7a4d2 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -12,7 +12,7 @@ import torch
 import torch.backends.cuda
 
 from toolkit.basic import value_map
-from toolkit.data_loader import get_dataloader_from_datasets
+from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
 from toolkit.embedding import Embedding
 from toolkit.ip_adapter import IPAdapter
@@ -931,16 +931,22 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     batch = next(dataloader_iterator_reg)
                 except StopIteration:
                     # hit the end of an epoch, reset
+                    self.progress_bar.pause()
                     dataloader_iterator_reg = iter(dataloader_reg)
+                    trigger_dataloader_setup_epoch(dataloader_reg)
                     batch = next(dataloader_iterator_reg)
+                    self.progress_bar.unpause()
                 is_reg_step = True
             elif dataloader is not None:
                 try:
                     batch = next(dataloader_iterator)
                 except StopIteration:
                     # hit the end of an epoch, reset
+                    self.progress_bar.pause()
                     dataloader_iterator = iter(dataloader)
+                    trigger_dataloader_setup_epoch(dataloader)
                     batch = next(dataloader_iterator)
+                    self.progress_bar.unpause()
             else:
                 batch = None
 
diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
index 492357eb..1b3bae07 100644
--- a/testing/test_bucket_dataloader.py
+++ b/testing/test_bucket_dataloader.py
@@ -2,6 +2,7 @@ import time
 
 import numpy as np
 import torch
+from torch.utils.data import DataLoader
 from torchvision import transforms
 import sys
 import os
@@ -16,12 +17,14 @@ sys.path.append(SD_SCRIPTS_ROOT)
 
 from library.model_util import load_vae
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
-from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
+from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
+    trigger_dataloader_setup_epoch
 from toolkit.config_modules import DatasetConfig
 import argparse
 
 parser = argparse.ArgumentParser()
 parser.add_argument('dataset_folder', type=str, default='input')
+parser.add_argument('--epochs', type=int, default=1)
 
 args = parser.parse_args()
 
@@ -34,38 +37,44 @@ batch_size = 4
 
 dataset_config = DatasetConfig(
     dataset_path=dataset_folder,
     resolution=resolution,
-    caption_ext='txt',
+    caption_ext='json',
     default_caption='default',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
-    augments=['ColorJitter', 'RandomEqualize'],
+    augments=['ColorJitter'],
+    poi='person'
 )
 
-dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
+dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
 
 # run through an epoch ang check sizes
-for batch in dataloader:
-    batch: 'DataLoaderBatchDTO'
-    img_batch = batch.tensor
+dataloader_iterator = iter(dataloader)
+for epoch in range(args.epochs):
+    for batch in dataloader:
+        batch: 'DataLoaderBatchDTO'
+        img_batch = batch.tensor
 
-    chunks = torch.chunk(img_batch, batch_size, dim=0)
-    # put them so they are size by side
-    big_img = torch.cat(chunks, dim=3)
-    big_img = big_img.squeeze(0)
+        chunks = torch.chunk(img_batch, batch_size, dim=0)
+        # put them so they are side by side
+        big_img = torch.cat(chunks, dim=3)
+        big_img = big_img.squeeze(0)
 
-    min_val = big_img.min()
-    max_val = big_img.max()
+        min_val = big_img.min()
+        max_val = big_img.max()
 
-    big_img = (big_img / 2 + 0.5).clamp(0, 1)
+        big_img = (big_img / 2 + 0.5).clamp(0, 1)
 
-    # convert to image
-    img = transforms.ToPILImage()(big_img)
+        # convert to image
+        img = transforms.ToPILImage()(big_img)
 
-    show_img(img)
+        show_img(img)
 
-    time.sleep(1.0)
+        time.sleep(1.0)
+    # if not last epoch
+    if epoch < args.epochs - 1:
+        trigger_dataloader_setup_epoch(dataloader)
 
 cv2.destroyAllWindows()
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index a78252ea..3f875217 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -234,7 +234,8 @@ class DatasetConfig:
         self.flip_x: bool = kwargs.get('flip_x', False)
         self.flip_y: bool = kwargs.get('flip_y', False)
         self.augments: List[str] = kwargs.get('augments', [])
-        self.control_path: str = kwargs.get('control_path', None)  # depth maps, etc
+        self.control_path: str = kwargs.get('control_path', None)  # depth maps, etc
+        self.poi: Union[str, None] = kwargs.get('poi', None)  # if one is set and in the json data, it will be used as the point of interest for auto crop/scale
         # cache latents will store them in memory
         self.cache_latents: bool = kwargs.get('cache_latents', False)
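Note: the new `poi` option expects the image's json sidecar to carry a top-level `poi` map keyed by name, where each entry has `x`, `y`, `width`, and `height` in source-image pixels (see `PoiFileItemDTOMixin` below). A minimal sketch of such a sidecar, written from Python; the `caption` key is an assumption here, only the `poi` block is read by this patch:

```python
import json

# Hypothetical sidecar for img_0001.jpg -> img_0001.json. The 'person' key
# must match the poi value in the dataset config; the box is the region that
# random crops must keep fully in frame.
sidecar = {
    "caption": "a person standing on a beach",  # assumed caption key
    "poi": {
        "person": {"x": 120, "y": 40, "width": 300, "height": 520},
    },
}
with open("img_0001.json", "w", encoding="utf-8") as f:
    json.dump(sidecar, f, indent=2)
```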
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 1af7e8ba..451ff4d2 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -346,6 +346,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
         self.is_caching_latents_to_memory = dataset_config.cache_latents
         self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
+        self.epoch_num = 0
 
         self.sd = sd
 
@@ -426,13 +427,20 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
         self.setup_epoch()
 
     def setup_epoch(self):
-        # TODO: set this up to redo cropping and everything else
-        # do not call for now
-        if self.dataset_config.buckets:
-            # setup buckets
-            self.setup_buckets()
-        if self.is_caching_latents:
-            self.cache_latents_all_latents()
+        if self.epoch_num == 0:
+            # initial setup
+            # do not call for now
+            if self.dataset_config.buckets:
+                # setup buckets
+                self.setup_buckets()
+            if self.is_caching_latents:
+                self.cache_latents_all_latents()
+        else:
+            if self.dataset_config.poi is not None:
+                # handle cropping to a specific point of interest
+                # setup buckets every epoch
+                self.setup_buckets(quiet=True)
+        self.epoch_num += 1
 
     def __len__(self):
         if self.dataset_config.buckets:
@@ -450,6 +458,9 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
             # for buckets we collate ourselves for now
             # todo allow a scheduler to dynamically make buckets
             # we collate ourselves
+            if len(self.batch_indices) - 1 < item:
+                # tried everything to solve this. No way to reset length when redoing things. Pick another index
+                item = random.randint(0, len(self.batch_indices) - 1)
             idx_list = self.batch_indices[item]
             return [self._get_single_item(idx) for idx in idx_list]
         else:
@@ -523,3 +534,27 @@ def get_dataloader_from_datasets(
         collate_fn=dto_collation
     )
     return data_loader
+
+
+def trigger_dataloader_setup_epoch(dataloader: DataLoader):
+    # hacky but needed because of different types of datasets and dataloaders
+    dataloader.len = None
+    if isinstance(dataloader.dataset, list):
+        for dataset in dataloader.dataset:
+            if hasattr(dataset, 'datasets'):
+                for sub_dataset in dataset.datasets:
+                    if hasattr(sub_dataset, 'setup_epoch'):
+                        sub_dataset.setup_epoch()
+                        sub_dataset.len = None
+            elif hasattr(dataset, 'setup_epoch'):
+                dataset.setup_epoch()
+                dataset.len = None
+    elif hasattr(dataloader.dataset, 'setup_epoch'):
+        dataloader.dataset.setup_epoch()
+        dataloader.dataset.len = None
+    elif hasattr(dataloader.dataset, 'datasets'):
+        dataloader.dataset.len = None
+        for sub_dataset in dataloader.dataset.datasets:
+            if hasattr(sub_dataset, 'setup_epoch'):
+                sub_dataset.setup_epoch()
+                sub_dataset.len = None
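Since `trigger_dataloader_setup_epoch` has to be invoked manually at each epoch boundary (as the trainer and test above do), a small wrapper can hide the bookkeeping. This helper is a sketch, not part of the patch:

```python
# Hypothetical helper: yield batches forever, re-running per-epoch setup
# (bucket rebuild, poi re-cropping) whenever the underlying epoch ends.
def endless_batches(dataloader):
    while True:
        for batch in dataloader:
            yield batch
        trigger_dataloader_setup_epoch(dataloader)
```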
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 98f73bfd..5550dec3 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
 
 from toolkit import image_utils
 from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
-    ControlFileItemDTOMixin, ArgBreakMixin
+    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin
 
 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
@@ -27,6 +27,7 @@ class FileItemDTO(
     CaptionProcessingDTOMixin,
     ImageProcessingDTOMixin,
     ControlFileItemDTOMixin,
+    PoiFileItemDTOMixin,
     ArgBreakMixin,
 ):
     def __init__(self, *args, **kwargs):
@@ -70,20 +71,25 @@ class FileItemDTO(
 
 class DataLoaderBatchDTO:
     def __init__(self, **kwargs):
-        self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
-        is_latents_cached = self.file_items[0].is_latent_cached
-        self.tensor: Union[torch.Tensor, None] = None
-        self.latents: Union[torch.Tensor, None] = None
-        if not is_latents_cached:
-            # only return a tensor if latents are not cached
-            self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
-        # if we have encoded latents, we concatenate them
-        self.latents: Union[torch.Tensor, None] = None
-        if is_latents_cached:
-            self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
-        self.control_tensor: Union[torch.Tensor, None] = None
-        if self.file_items[0].control_tensor is not None:
-            self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
+        try:
+            self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
+            is_latents_cached = self.file_items[0].is_latent_cached
+            self.tensor: Union[torch.Tensor, None] = None
+            self.latents: Union[torch.Tensor, None] = None
+            if not is_latents_cached:
+                # only return a tensor if latents are not cached
+                self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
+            # if we have encoded latents, we concatenate them
+            self.latents: Union[torch.Tensor, None] = None
+            if is_latents_cached:
+                self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
+            self.control_tensor: Union[torch.Tensor, None] = None
+            if self.file_items[0].control_tensor is not None:
+                self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
+        except Exception as e:
+            print(e)
+            raise e
+
 
     def get_is_reg_list(self):
         return [x.is_reg for x in self.file_items]
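For context on why the new `PoiFileItemDTOMixin.__init__` (below) guards its `super().__init__` call: `FileItemDTO` chains cooperative mixin initializers, and `ArgBreakMixin` exists to swallow the arguments before they reach `object.__init__`. A stripped-down sketch of the pattern, with illustrative names:

```python
class ArgBreak:
    # absorbs *args/**kwargs so they never reach object.__init__
    def __init__(self, *args, **kwargs):
        pass


class PoiSketch:
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # forward along the MRO
        self.poi = kwargs.get('poi', None)


class ItemSketch(PoiSketch, ArgBreak):
    pass


item = ItemSketch(poi='person')  # kwargs flow through every mixin safely
```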
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 4c9c902f..8e13aceb 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -29,11 +29,13 @@ if TYPE_CHECKING:
 
 transforms_dict = {
-    'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
+    'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
     'RandomEqualize': transforms.RandomEqualize(p=0.2),
 }
 
 caption_ext_list = ['txt', 'json', 'caption']
+
+
 class CaptionMixin:
     def get_caption_item(self: 'AiToolkitDataset', index):
         if not hasattr(self, 'caption_type'):
@@ -106,66 +108,96 @@ class BucketsMixin:
         self.batch_indices: List[List[int]] = []
 
     def build_batch_indices(self: 'AiToolkitDataset'):
+        self.batch_indices = []
         for key, bucket in self.buckets.items():
             for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
                 end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
                 batch = bucket.file_list_idx[start_idx:end_idx]
                 self.batch_indices.append(batch)
 
-    def setup_buckets(self: 'AiToolkitDataset'):
+    def setup_buckets(self: 'AiToolkitDataset', quiet=False):
         if not hasattr(self, 'file_list'):
             raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
         if not hasattr(self, 'dataset_config'):
             raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}')
 
+        if self.epoch_num > 0 and self.dataset_config.poi is None:
+            # no need to rebuild buckets for now
+            # todo handle random cropping for buckets
+            return
+
+        self.buckets = {}  # clear it
+
         config: 'DatasetConfig' = self.dataset_config
         resolution = config.resolution
         bucket_tolerance = config.bucket_tolerance
         file_list: List['FileItemDTO'] = self.file_list
 
-        total_pixels = resolution * resolution
-        # for file_item in enumerate(file_list):
         for idx, file_item in enumerate(file_list):
             file_item: 'FileItemDTO' = file_item
-            width = file_item.crop_width
-            height = file_item.crop_height
+            width = int(file_item.width * file_item.dataset_config.scale)
+            height = int(file_item.height * file_item.dataset_config.scale)
 
-            bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution,
-                                                          divisibility=bucket_tolerance)
-
-            # set the scaling height and with to match smallest size, and keep aspect ratio
-            if width > height:
-                file_item.scale_to_height = bucket_resolution["height"]
-                file_item.scale_to_width = int(width * (bucket_resolution["height"] / height))
+            if file_item.has_point_of_interest:
+                # let the poi module handle the bucketing
+                file_item.setup_poi_bucket()
             else:
-                file_item.scale_to_width = bucket_resolution["width"]
-                file_item.scale_to_height = int(height * (bucket_resolution["width"] / width))
+                bucket_resolution = get_bucket_for_image_size(
+                    width, height,
+                    resolution=resolution,
+                    divisibility=bucket_tolerance
+                )
 
-            file_item.crop_height = bucket_resolution["height"]
-            file_item.crop_width = bucket_resolution["width"]
+                # Calculate scale factors for width and height
+                width_scale_factor = bucket_resolution["width"] / width
+                height_scale_factor = bucket_resolution["height"] / height
 
-            new_width = bucket_resolution["width"]
-            new_height = bucket_resolution["height"]
+                # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
+                max_scale_factor = max(width_scale_factor, height_scale_factor)
+
+                file_item.scale_to_width = int(width * max_scale_factor)
+                file_item.scale_to_height = int(height * max_scale_factor)
+
+                file_item.crop_height = bucket_resolution["height"]
+                file_item.crop_width = bucket_resolution["width"]
+
+                new_width = bucket_resolution["width"]
+                new_height = bucket_resolution["height"]
+
+                if self.dataset_config.random_crop:
+                    # random crop
+                    crop_x = random.randint(0, file_item.scale_to_width - new_width)
+                    crop_y = random.randint(0, file_item.scale_to_height - new_height)
+                    file_item.crop_x = crop_x
+                    file_item.crop_y = crop_y
+                else:
+                    # do central crop
+                    file_item.crop_x = int((file_item.scale_to_width - new_width) / 2)
+                    file_item.crop_y = int((file_item.scale_to_height - new_height) / 2)
+
+            if file_item.crop_y < 0 or file_item.crop_x < 0:
+                print('debug')
 
             # check if bucket exists, if not, create it
-            bucket_key = f'{new_width}x{new_height}'
+            bucket_key = f'{file_item.crop_width}x{file_item.crop_height}'
             if bucket_key not in self.buckets:
-                self.buckets[bucket_key] = Bucket(new_width, new_height)
+                self.buckets[bucket_key] = Bucket(file_item.crop_width, file_item.crop_height)
             self.buckets[bucket_key].file_list_idx.append(idx)
 
         # print the buckets
         self.build_batch_indices()
-        name = f"{os.path.basename(self.dataset_path)} ({self.resolution})"
-        print(f'Bucket sizes for {self.dataset_path}:')
-        for key, bucket in self.buckets.items():
-            print(f'{key}: {len(bucket.file_list_idx)} files')
-        print(f'{len(self.buckets)} buckets made')
+        if not quiet:
+            print(f'Bucket sizes for {self.dataset_path}:')
+            for key, bucket in self.buckets.items():
+                print(f'{key}: {len(bucket.file_list_idx)} files')
+            print(f'{len(self.buckets)} buckets made')
 
-        # file buckets made
 
 class CaptionProcessingDTOMixin:
+    def __init__(self: 'FileItemDTO', *args, **kwargs):
+        if hasattr(super(), '__init__'):
+            super().__init__(*args, **kwargs)
 
     # todo allow for loading from sd-scripts style dict
     def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
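To make the scaling rule in the hunk above concrete, here is a worked example; the bucket size is illustrative, not an actual output of `get_bucket_for_image_size`:

```python
# Suppose a 1000x700 source image lands in a 512x384 bucket (made-up numbers).
width_scale_factor = 512 / 1000   # 0.512
height_scale_factor = 384 / 700   # ~0.549

# Taking the max guarantees both scaled dimensions cover the bucket, so the
# subsequent crop never has to sample outside the image.
max_scale_factor = max(width_scale_factor, height_scale_factor)
scale_to_width = int(1000 * max_scale_factor)   # 548 -> 36px of width to crop
scale_to_height = int(700 * max_scale_factor)   # 384 -> exactly bucket height
```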
@@ -281,10 +313,19 @@ class ImageProcessingDTOMixin:
                 img.transpose(Image.FLIP_TOP_BOTTOM)
 
         if self.dataset_config.buckets:
-            # todo allow scaling and cropping, will be hard to add
             # scale and crop based on file item
             img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
-            img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+            # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
+            if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
+                print('size mismatch')
+            img = img.crop((
+                self.crop_x,
+                self.crop_y,
+                self.crop_x + self.crop_width,
+                self.crop_y + self.crop_height
+            ))
+
+            # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
         else:
             # Downscale the source image first
             # TODO this is nto right
@@ -371,7 +412,14 @@ class ControlFileItemDTOMixin:
         if self.dataset_config.buckets:
             # scale and crop based on file item
             img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
-            img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+            # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+            # crop
+            img = img.crop((
+                self.crop_x,
+                self.crop_y,
+                self.crop_x + self.crop_width,
+                self.crop_y + self.crop_height
+            ))
         else:
             raise Exception("Control images not supported for non-bucket datasets")
@@ -381,6 +429,93 @@
         self.control_tensor = None
 
 
+class PoiFileItemDTOMixin:
+    # Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
+    # items in the poi will always be inside the image when random cropping
+    def __init__(self: 'FileItemDTO', *args, **kwargs):
+        if hasattr(super(), '__init__'):
+            super().__init__(*args, **kwargs)
+        # poi is a name of the box point of interest in the caption json file
+        dataset_config = kwargs.get('dataset_config', None)
+        path = kwargs.get('path', None)
+        self.poi: Union[str, None] = dataset_config.poi
+        self.has_point_of_interest = self.poi is not None
+        self.poi_x: Union[int, None] = None
+        self.poi_y: Union[int, None] = None
+        self.poi_width: Union[int, None] = None
+        self.poi_height: Union[int, None] = None
+
+        if self.poi is not None:
+            # make sure latent caching is off
+            if dataset_config.cache_latents or dataset_config.cache_latents_to_disk:
+                raise Exception(
+                    f"Error: poi is not supported when caching latents. Please set cache_latents and cache_latents_to_disk to False in the dataset config"
+                )
+            # make sure we are loading through json
+            if dataset_config.caption_ext != 'json':
+                raise Exception(
+                    f"Error: poi is only supported when using json captions. Please set caption_ext to json in the dataset config"
+                )
+            self.poi = self.poi.strip()
+            # get the caption path
+            file_path_no_ext = os.path.splitext(path)[0]
+            caption_path = file_path_no_ext + '.json'
+            if not os.path.exists(caption_path):
+                raise Exception(f"Error: caption file not found for poi: {caption_path}")
+            with open(caption_path, 'r', encoding='utf-8') as f:
+                json_data = json.load(f)
+            if 'poi' not in json_data:
+                raise Exception(f"Error: poi not found in caption file: {caption_path}")
+            if self.poi not in json_data['poi']:
+                raise Exception(f"Error: poi not found in caption file: {caption_path}")
+            # poi has x, y, width, height
+            poi = json_data['poi'][self.poi]
+            self.poi_x = int(poi['x'])
+            self.poi_y = int(poi['y'])
+            self.poi_width = int(poi['width'])
+            self.poi_height = int(poi['height'])
+
+    def setup_poi_bucket(self: 'FileItemDTO'):
+        # we are using poi, so we need to calculate the bucket based on the poi
+
+        resolution = self.dataset_config.resolution
+        bucket_tolerance = self.dataset_config.bucket_tolerance
+        initial_width = int(self.width * self.dataset_config.scale)
+        initial_height = int(self.height * self.dataset_config.scale)
+        poi_x = int(self.poi_x * self.dataset_config.scale)
+        poi_y = int(self.poi_y * self.dataset_config.scale)
+        poi_width = int(self.poi_width * self.dataset_config.scale)
+        poi_height = int(self.poi_height * self.dataset_config.scale)
+
+        # todo handle a poi that is smaller than resolution
+        # determine new cropping
+        crop_left = random.randint(0, poi_x)
+        crop_right = random.randint(poi_x + poi_width, initial_width)
+        crop_top = random.randint(0, poi_y)
+        crop_bottom = random.randint(poi_y + poi_height, initial_height)
+
+        new_width = crop_right - crop_left
+        new_height = crop_bottom - crop_top
+
+        bucket_resolution = get_bucket_for_image_size(
+            new_width, new_height,
+            resolution=resolution,
+            divisibility=bucket_tolerance
+        )
+
+        width_scale_factor = bucket_resolution["width"] / new_width
+        height_scale_factor = bucket_resolution["height"] / new_height
+        # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
+        max_scale_factor = max(width_scale_factor, height_scale_factor)
+
+        self.scale_to_width = int(initial_width * max_scale_factor)
+        self.scale_to_height = int(initial_height * max_scale_factor)
+        self.crop_width = bucket_resolution['width']
+        self.crop_height = bucket_resolution['height']
+        self.crop_x = int(crop_left * max_scale_factor)
+        self.crop_y = int(crop_top * max_scale_factor)
+
+
 class ArgBreakMixin:
     # just stops super calls form hitting object
     def __init__(self, *args, **kwargs):
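A quick property check of the crop bounds sampled in `setup_poi_bucket`: since `crop_left <= poi_x` and `crop_right >= poi_x + poi_width` by construction (and likewise on the vertical axis), the point of interest always survives the random crop. Sketch with hypothetical numbers:

```python
import random

width, height = 1024, 768
poi_x, poi_y, poi_w, poi_h = 300, 200, 200, 300  # hypothetical poi box

for _ in range(1000):
    crop_left = random.randint(0, poi_x)
    crop_right = random.randint(poi_x + poi_w, width)
    crop_top = random.randint(0, poi_y)
    crop_bottom = random.randint(poi_y + poi_h, height)
    # the sampled crop window always contains the poi box
    assert crop_left <= poi_x <= poi_x + poi_w <= crop_right
    assert crop_top <= poi_y <= poi_y + poi_h <= crop_bottom
```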