From bedb8197a29cf885b284e137a397d9bb8de620ec Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 20 Oct 2024 11:51:29 -0600
Subject: [PATCH] Fixed issue with sizes for some images being loaded sideways
 resulting in squished images.

---
 toolkit/data_loader.py                      |  9 +++++++
 toolkit/data_transfer_object/data_loader.py | 28 +++++++++++++--------
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 103125e..5285b37 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -441,16 +441,24 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         if not os.path.isdir(self.dataset_path):
             dataset_folder = os.path.dirname(dataset_folder)
         dataset_size_file = os.path.join(dataset_folder, '.aitk_size.json')
+        dataloader_version = "0.1.1"
         if os.path.exists(dataset_size_file):
             try:
                 with open(dataset_size_file, 'r') as f:
                     self.size_database = json.load(f)
+
+                if "__version__" not in self.size_database or self.size_database["__version__"] != dataloader_version:
+                    print("Upgrading size database to new version")
+                    # old version, delete and recreate
+                    self.size_database = {}
             except Exception as e:
                 print(f"Error loading size database: {dataset_size_file}")
                 print(e)
                 self.size_database = {}
         else:
             self.size_database = {}
+
+        self.size_database["__version__"] = dataloader_version
 
         bad_count = 0
         for file in tqdm(file_list):
@@ -461,6 +469,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                     dataset_config=dataset_config,
                     dataloader_transforms=self.transform,
                     size_database=self.size_database,
+                    dataset_root=dataset_folder,
                 )
                 self.file_list.append(file_item)
             except Exception as e:
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 8bc0314..34239f4 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -44,19 +44,25 @@ class FileItemDTO(
         self.path = kwargs.get('path', '')
         self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
         size_database = kwargs.get('size_database', {})
-        filename = os.path.basename(self.path)
-        if filename in size_database:
-            w, h = size_database[filename]
+        dataset_root = kwargs.get('dataset_root', None)
+        if dataset_root is not None:
+            # remove dataset root from path
+            file_key = self.path.replace(dataset_root, '')
         else:
+            file_key = os.path.basename(self.path)
+        if file_key in size_database:
+            w, h = size_database[file_key]
+        else:
+            # original method is significantly faster, but some images are read sideways. Not sure why. Do slow method for now.
             # process width and height
-            try:
-                w, h = image_utils.get_image_size(self.path)
-            except image_utils.UnknownImageFormat:
-                print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
-                           f'This process is faster for png, jpeg')
-                img = exif_transpose(Image.open(self.path))
-                w, h = img.size
-            size_database[filename] = (w, h)
+            # try:
+            #     w, h = image_utils.get_image_size(self.path)
+            # except image_utils.UnknownImageFormat:
+            #     print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
+            #                f'This process is faster for png, jpeg')
+            img = exif_transpose(Image.open(self.path))
+            w, h = img.size
+            size_database[file_key] = (w, h)
         self.width: int = w
         self.height: int = h
         self.dataloader_transforms = kwargs.get('dataloader_transforms', None)
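Note on the "read sideways" comment: the patch does not pin down the cause, but a plausible explanation is the EXIF Orientation tag. Header-only size readers such as image_utils.get_image_size (and plain Image.open(...).size) return the stored pixel dimensions and ignore orientation, while cameras often store photos rotated and record the rotation in EXIF, so width and height come back swapped and the aspect-ratio bucketing then squishes the image. exif_transpose applies the orientation before .size is read, which is why the patch falls back to it. The sketch below is illustrative only and not part of the patch; it assumes a recent Pillow and uses a throwaway file name, rotated_photo.jpg.

from PIL import Image
from PIL.ImageOps import exif_transpose

# Hypothetical throwaway file; any JPEG whose EXIF Orientation tag is 5-8
# (i.e. stored rotated) behaves the same way.
path = "rotated_photo.jpg"

# Create a 300x200 image and save it with Orientation = 6, meaning
# "rotate 90 degrees clockwise to display" -- EXIF-aware viewers show it
# as 200x300 (portrait).
img = Image.new("RGB", (300, 200), "white")
exif = img.getexif()
exif[0x0112] = 6  # 0x0112 is the EXIF Orientation tag
img.save(path, exif=exif)

# Stored dimensions, orientation ignored -- what a header-only fast read
# reports, so the recorded aspect ratio is transposed.
print(Image.open(path).size)                  # (300, 200)

# exif_transpose applies the rotation first, so .size matches the image as
# displayed -- what the patched code now records in the size database.
print(exif_transpose(Image.open(path)).size)  # (200, 300)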