From 085787b799f17d7f68f4751f1e77bf281107ad7c Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 30 Sep 2023 07:28:23 -0600 Subject: [PATCH] Allow loading auxiliary images from dataloader --- toolkit/config_modules.py | 2 +- toolkit/data_transfer_object/data_loader.py | 20 ++++-- toolkit/dataloader_mixins.py | 75 ++++++++++++++++++++- 3 files changed, 89 insertions(+), 8 deletions(-) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 7fda9382..a78252ea 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -231,10 +231,10 @@ class DatasetConfig: self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0)) self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False) self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0)) - self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0)) self.flip_x: bool = kwargs.get('flip_x', False) self.flip_y: bool = kwargs.get('flip_y', False) self.augments: List[str] = kwargs.get('augments', []) + self.control_path: str = kwargs.get('control_path', None) # depth maps, etc # cache latents will store them in memory self.cache_latents: bool = kwargs.get('cache_latents', False) diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index 842375f0..98f73bfd 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -6,7 +6,8 @@ from PIL import Image from PIL.ImageOps import exif_transpose from toolkit import image_utils -from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin +from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ +    ControlFileItemDTOMixin, ArgBreakMixin if TYPE_CHECKING: from toolkit.config_modules import DatasetConfig @@ -21,9 +22,15 @@ def print_once(msg): 
printed_messages.append(msg) -class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, ImageProcessingDTOMixin): - def __init__(self, **kwargs): - super().__init__() +class FileItemDTO( + LatentCachingFileItemDTOMixin, + CaptionProcessingDTOMixin, + ImageProcessingDTOMixin, + ControlFileItemDTOMixin, + ArgBreakMixin, +): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.path = kwargs.get('path', None) self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) # process width and height @@ -58,6 +65,7 @@ class FileItemDTO(LatentCachingFileItemDTOMixin, CaptionProcessingDTOMixin, Imag def cleanup(self): self.tensor = None self.cleanup_latent() + self.cleanup_control() class DataLoaderBatchDTO: @@ -73,6 +81,9 @@ class DataLoaderBatchDTO: self.latents: Union[torch.Tensor, None] = None if is_latents_cached: self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) + self.control_tensor: Union[torch.Tensor, None] = None + if self.file_items[0].control_tensor is not None: + self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items]) def get_is_reg_list(self): return [x.is_reg for x in self.file_items] @@ -95,5 +106,6 @@ class DataLoaderBatchDTO: def cleanup(self): del self.latents del self.tensor + del self.control_tensor for file_item in self.file_items: file_item.cleanup() diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 4ea24e20..7742d455 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -121,7 +121,8 @@ class BucketsMixin: width = file_item.crop_width height = file_item.crop_height - bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution, divisibility=bucket_tolerance) + bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution, + divisibility=bucket_tolerance) # set the scaling height and with to match smallest size, and keep aspect ratio if 
width > height: @@ -239,6 +240,8 @@ class ImageProcessingDTOMixin: # if we are caching latents, just do that if self.is_latent_cached: self.get_latent() + if self.has_control_image: + self.load_control_image() return try: img = Image.open(self.path).convert('RGB') @@ -302,13 +305,79 @@ class ImageProcessingDTOMixin: img = transform(img) self.tensor = img + if self.has_control_image: + self.load_control_image() + + +class ControlFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_control_image = False + self.control_path: Union[str, None] = None + self.control_tensor: Union[torch.Tensor, None] = None + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + if dataset_config.control_path is not None: + # find the control image path + control_path = dataset_config.control_path + # we are using control images + img_path = kwargs.get('path', None) + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + for ext in img_ext_list: + if os.path.exists(os.path.join(control_path, file_name_no_ext + ext)): + self.control_path = os.path.join(control_path, file_name_no_ext + ext) + self.has_control_image = True + break + + def load_control_image(self: 'FileItemDTO'): + try: + img = Image.open(self.control_path).convert('RGB') + img = exif_transpose(img) + except Exception as e: + print(f"Error: {e}") + print(f"Error loading image: {self.control_path}") + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, 
file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + + if self.flip_x: + # do a flip + img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_y: + # do a flip + img.transpose(Image.FLIP_TOP_BOTTOM) + + if self.dataset_config.buckets: + # scale and crop based on file item + img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) + img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + else: + raise Exception("Control images not supported for non-bucket datasets") + + self.control_tensor = transforms.ToTensor()(img) + + def cleanup_control(self: 'FileItemDTO'): + self.control_tensor = None + + +class ArgBreakMixin: + # just stops super calls from hitting object + def __init__(self, *args, **kwargs): + pass class LatentCachingFileItemDTOMixin: - def __init__(self): + def __init__(self, *args, **kwargs): # if we have super, call it if hasattr(super(), '__init__'): - super().__init__() + super().__init__(*args, **kwargs) self._encoded_latent: Union[torch.Tensor, None] = None self._latent_path: Union[str, None] = None self.is_latent_cached = False