mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
allow smaller images in buckets and bucket them
This commit is contained in:
@@ -72,15 +72,27 @@ def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[Bucke
|
|||||||
return bucket_size_list
|
return bucket_size_list
|
||||||
|
|
||||||
|
|
||||||
|
def get_resolution(width, height):
|
||||||
|
num_pixels = width * height
|
||||||
|
# determine same number of pixels for square image
|
||||||
|
square_resolution = int(num_pixels ** 0.5)
|
||||||
|
return square_resolution
|
||||||
|
|
||||||
|
|
||||||
def get_bucket_for_image_size(
|
def get_bucket_for_image_size(
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
bucket_size_list: List[BucketResolution] = None,
|
bucket_size_list: List[BucketResolution] = None,
|
||||||
resolution: Union[int, None] = None
|
resolution: Union[int, None] = None
|
||||||
) -> BucketResolution:
|
) -> BucketResolution:
|
||||||
|
|
||||||
if bucket_size_list is None and resolution is None:
|
if bucket_size_list is None and resolution is None:
|
||||||
raise ValueError("Must provide either bucket_size_list or resolution")
|
# get resolution from width and height
|
||||||
|
resolution = get_resolution(width, height)
|
||||||
if bucket_size_list is None:
|
if bucket_size_list is None:
|
||||||
|
# if real resolution is smaller, use that instead
|
||||||
|
real_resolution = get_resolution(width, height)
|
||||||
|
resolution = min(resolution, real_resolution)
|
||||||
bucket_size_list = get_bucket_sizes(resolution=resolution)
|
bucket_size_list = get_bucket_sizes(resolution=resolution)
|
||||||
|
|
||||||
# Check for exact match first
|
# Check for exact match first
|
||||||
|
|||||||
@@ -331,17 +331,17 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
|||||||
path=file,
|
path=file,
|
||||||
dataset_config=dataset_config
|
dataset_config=dataset_config
|
||||||
)
|
)
|
||||||
if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
|
# if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
|
||||||
bad_count += 1
|
# bad_count += 1
|
||||||
else:
|
# else:
|
||||||
self.file_list.append(file_item)
|
self.file_list.append(file_item)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing image: {file}")
|
print(f"Error processing image: {file}")
|
||||||
print(e)
|
print(e)
|
||||||
bad_count += 1
|
bad_count += 1
|
||||||
|
|
||||||
print(f" - Found {len(self.file_list)} images")
|
print(f" - Found {len(self.file_list)} images")
|
||||||
print(f" - Found {bad_count} images that are too small")
|
# print(f" - Found {bad_count} images that are too small")
|
||||||
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
|
||||||
|
|
||||||
self.setup_epoch()
|
self.setup_epoch()
|
||||||
|
|||||||
Reference in New Issue
Block a user