diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index a09ed5d8..7fd5295b 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -105,46 +105,51 @@ class BucketsMixin: # determine new resolution to have the same number of pixels current_pixels = width * height if current_pixels == total_pixels: - # no change - continue - - aspect_ratio = width / height - new_height = int(math.sqrt(total_pixels / aspect_ratio)) - new_width = int(aspect_ratio * new_height) - - # increase smallest one to be divisible by bucket_tolerance and increase the other to match - if new_width < new_height: - # increase width - if new_width % bucket_tolerance != 0: - crop_amount = new_width % bucket_tolerance - new_width = new_width + (bucket_tolerance - crop_amount) - new_height = int(new_width / aspect_ratio) + file_item.scale_to_width = width + file_item.scale_to_height = height + file_item.crop_width = width + file_item.crop_height = height + new_width = width + new_height = height else: - # increase height - if new_height % bucket_tolerance != 0: - crop_amount = new_height % bucket_tolerance - new_height = new_height + (bucket_tolerance - crop_amount) + + aspect_ratio = width / height + new_height = int(math.sqrt(total_pixels / aspect_ratio)) new_width = int(aspect_ratio * new_height) - # Ensure that the total number of pixels remains the same. - # assert new_width * new_height == total_pixels + # increase smallest one to be divisible by bucket_tolerance and increase the other to match + if new_width < new_height: + # increase width + if new_width % bucket_tolerance != 0: + crop_amount = new_width % bucket_tolerance + new_width = new_width + (bucket_tolerance - crop_amount) + new_height = int(new_width / aspect_ratio) + else: + # increase height + if new_height % bucket_tolerance != 0: + crop_amount = new_height % bucket_tolerance + new_height = new_height + (bucket_tolerance - crop_amount) + new_width = int(aspect_ratio * new_height) - file_item.scale_to_width = new_width - file_item.scale_to_height = new_height - file_item.crop_width = new_width - file_item.crop_height = new_height - # make sure it is divisible by bucket_tolerance, decrease if not - if new_width % bucket_tolerance != 0: - crop_amount = new_width % bucket_tolerance - file_item.crop_width = new_width - crop_amount - else: + # Ensure that the total number of pixels remains the same. + # assert new_width * new_height == total_pixels + + file_item.scale_to_width = new_width + file_item.scale_to_height = new_height file_item.crop_width = new_width - - if new_height % bucket_tolerance != 0: - crop_amount = new_height % bucket_tolerance - file_item.crop_height = new_height - crop_amount - else: file_item.crop_height = new_height + # make sure it is divisible by bucket_tolerance, decrease if not + if new_width % bucket_tolerance != 0: + crop_amount = new_width % bucket_tolerance + file_item.crop_width = new_width - crop_amount + else: + file_item.crop_width = new_width + + if new_height % bucket_tolerance != 0: + crop_amount = new_height % bucket_tolerance + file_item.crop_height = new_height - crop_amount + else: + file_item.crop_height = new_height # check if bucket exists, if not, create it bucket_key = f'{new_width}x{new_height}' @@ -154,7 +159,8 @@ class BucketsMixin: # print the buckets self.build_batch_indices() - print(f'Bucket sizes for {self.__class__.__name__}:') + name = f"{os.path.basename(self.dataset_path)} ({self.resolution})" + print(f'Bucket sizes for {self.dataset_path}:') for key, bucket in self.buckets.items(): print(f'{key}: {len(bucket.file_list_idx)} files') print(f'{len(self.buckets)} buckets made')