allow smaller images in buckets and bucket them

This commit is contained in:
Jaret Burkett
2023-09-10 03:43:02 -06:00
parent 626ed2939a
commit 41a3f63b72
2 changed files with 18 additions and 6 deletions

View File

@@ -72,15 +72,27 @@ def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[Bucke
return bucket_size_list return bucket_size_list
def get_resolution(width, height):
num_pixels = width * height
# determine same number of pixels for square image
square_resolution = int(num_pixels ** 0.5)
return square_resolution
def get_bucket_for_image_size( def get_bucket_for_image_size(
width: int, width: int,
height: int, height: int,
bucket_size_list: List[BucketResolution] = None, bucket_size_list: List[BucketResolution] = None,
resolution: Union[int, None] = None resolution: Union[int, None] = None
) -> BucketResolution: ) -> BucketResolution:
if bucket_size_list is None and resolution is None: if bucket_size_list is None and resolution is None:
raise ValueError("Must provide either bucket_size_list or resolution") # get resolution from width and height
resolution = get_resolution(width, height)
if bucket_size_list is None: if bucket_size_list is None:
# if real resolution is smaller, use that instead
real_resolution = get_resolution(width, height)
resolution = min(resolution, real_resolution)
bucket_size_list = get_bucket_sizes(resolution=resolution) bucket_size_list = get_bucket_sizes(resolution=resolution)
# Check for exact match first # Check for exact match first

View File

@@ -331,17 +331,17 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
path=file, path=file,
dataset_config=dataset_config dataset_config=dataset_config
) )
if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution: # if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
bad_count += 1 # bad_count += 1
else: # else:
self.file_list.append(file_item) self.file_list.append(file_item)
except Exception as e: except Exception as e:
print(f"Error processing image: {file}") print(f"Error processing image: {file}")
print(e) print(e)
bad_count += 1 bad_count += 1
print(f" - Found {len(self.file_list)} images") print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small") # print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}" assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
self.setup_epoch() self.setup_epoch()