mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Numerous fixes for time sampling. Still not perfect
This commit is contained in:
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
# def get_associated_caption_from_img_path(img_path):
|
||||
|
||||
# https://demo.albumentations.ai/
|
||||
class Augments:
|
||||
def __init__(self, **kwargs):
|
||||
self.method_name = kwargs.get('method', None)
|
||||
@@ -167,10 +167,11 @@ class BucketsMixin:
|
||||
width = int(file_item.width * file_item.dataset_config.scale)
|
||||
height = int(file_item.height * file_item.dataset_config.scale)
|
||||
|
||||
did_process_poi = False
|
||||
if file_item.has_point_of_interest:
|
||||
# let the poi module handle the bucketing
|
||||
file_item.setup_poi_bucket()
|
||||
else:
|
||||
# Attempt to process the poi if we can. It wont process if the image is smaller than the resolution
|
||||
did_process_poi = file_item.setup_poi_bucket()
|
||||
if not did_process_poi:
|
||||
bucket_resolution = get_bucket_for_image_size(
|
||||
width, height,
|
||||
resolution=resolution,
|
||||
@@ -323,7 +324,7 @@ class CaptionProcessingDTOMixin:
|
||||
|
||||
if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
|
||||
# add random triggers
|
||||
caption = random.choice(self.dataset_config.random_triggers) + ', ' + caption
|
||||
caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
|
||||
|
||||
if self.dataset_config.shuffle_tokens:
|
||||
# shuffle again
|
||||
@@ -803,79 +804,68 @@ class PoiFileItemDTOMixin:
|
||||
self.poi_y = self.height - self.poi_y - self.poi_height
|
||||
|
||||
def setup_poi_bucket(self: 'FileItemDTO'):
|
||||
# we are using poi, so we need to calculate the bucket based on the poi
|
||||
|
||||
# TODO this will allow poi to be smaller than resolution. Could affect training image size
|
||||
poi_resolution = min(
|
||||
self.dataset_config.resolution,
|
||||
get_resolution(
|
||||
self.poi_width * self.dataset_config.scale,
|
||||
self.poi_height * self.dataset_config.scale
|
||||
)
|
||||
)
|
||||
|
||||
resolution = min(self.dataset_config.resolution, poi_resolution)
|
||||
|
||||
bucket_tolerance = self.dataset_config.bucket_tolerance
|
||||
initial_width = int(self.width * self.dataset_config.scale)
|
||||
initial_height = int(self.height * self.dataset_config.scale)
|
||||
# we are using poi, so we need to calculate the bucket based on the poi
|
||||
|
||||
# if img resolution is less than dataset resolution, just return and let the normal bucketing happen
|
||||
img_resolution = get_resolution(initial_width, initial_height)
|
||||
if img_resolution <= self.dataset_config.resolution:
|
||||
return False # will trigger normal bucketing
|
||||
|
||||
bucket_tolerance = self.dataset_config.bucket_tolerance
|
||||
poi_x = int(self.poi_x * self.dataset_config.scale)
|
||||
poi_y = int(self.poi_y * self.dataset_config.scale)
|
||||
poi_width = int(self.poi_width * self.dataset_config.scale)
|
||||
poi_height = int(self.poi_height * self.dataset_config.scale)
|
||||
|
||||
# expand poi to fit resolution
|
||||
if poi_width < resolution:
|
||||
width_difference = resolution - poi_width
|
||||
poi_x = poi_x - int(width_difference / 2)
|
||||
poi_width = resolution
|
||||
# make sure we dont go out of bounds
|
||||
if poi_x < 0:
|
||||
# loop to keep expanding until we are at the proper resolution. This is not ideal, we can probably handle it better
|
||||
num_loops = 0
|
||||
while True:
|
||||
# crop left
|
||||
if poi_x > 0:
|
||||
poi_x = random.randint(0, poi_x)
|
||||
else:
|
||||
poi_x = 0
|
||||
# if total width too much, crop
|
||||
if poi_x + poi_width > initial_width:
|
||||
poi_width = initial_width - poi_x
|
||||
|
||||
if poi_height < resolution:
|
||||
height_difference = resolution - poi_height
|
||||
poi_y = poi_y - int(height_difference / 2)
|
||||
poi_height = resolution
|
||||
# make sure we dont go out of bounds
|
||||
if poi_y < 0:
|
||||
# crop right
|
||||
cr_min = poi_x + poi_width
|
||||
if cr_min < initial_width:
|
||||
crop_right = random.randint(poi_x + poi_width, initial_width)
|
||||
else:
|
||||
crop_right = initial_width
|
||||
|
||||
poi_width = crop_right - poi_x
|
||||
|
||||
if poi_y > 0:
|
||||
poi_y = random.randint(0, poi_y)
|
||||
else:
|
||||
poi_y = 0
|
||||
# if total height too much, crop
|
||||
if poi_y + poi_height > initial_height:
|
||||
poi_height = initial_height - poi_y
|
||||
|
||||
# crop left
|
||||
if poi_x > 0:
|
||||
crop_left = random.randint(0, poi_x)
|
||||
else:
|
||||
crop_left = 0
|
||||
if poi_y + poi_height < initial_height:
|
||||
crop_bottom = random.randint(poi_y + poi_height, initial_height)
|
||||
else:
|
||||
crop_bottom = initial_height
|
||||
|
||||
# crop right
|
||||
cr_min = poi_x + poi_width
|
||||
if cr_min < initial_width:
|
||||
crop_right = random.randint(poi_x + poi_width, initial_width)
|
||||
else:
|
||||
crop_right = initial_width
|
||||
poi_height = crop_bottom - poi_y
|
||||
# now we have our random crop, but it may be smaller than resolution. Check and expand if needed
|
||||
current_resolution = get_resolution(poi_width, poi_height)
|
||||
if current_resolution >= self.dataset_config.resolution:
|
||||
# We can break now
|
||||
break
|
||||
else:
|
||||
num_loops += 1
|
||||
if num_loops > 100:
|
||||
print(
|
||||
f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.")
|
||||
return False
|
||||
|
||||
if poi_y > 0:
|
||||
crop_top = random.randint(0, poi_y)
|
||||
else:
|
||||
crop_top = 0
|
||||
|
||||
if poi_y + poi_height < initial_height:
|
||||
crop_bottom = random.randint(poi_y + poi_height, initial_height)
|
||||
else:
|
||||
crop_bottom = initial_height
|
||||
|
||||
new_width = crop_right - crop_left
|
||||
new_height = crop_bottom - crop_top
|
||||
new_width = poi_width
|
||||
new_height = poi_height
|
||||
|
||||
bucket_resolution = get_bucket_for_image_size(
|
||||
new_width, new_height,
|
||||
resolution=resolution,
|
||||
resolution=self.dataset_config.resolution,
|
||||
divisibility=bucket_tolerance
|
||||
)
|
||||
|
||||
@@ -888,8 +878,10 @@ class PoiFileItemDTOMixin:
|
||||
self.scale_to_height = int(initial_height * max_scale_factor)
|
||||
self.crop_width = bucket_resolution['width']
|
||||
self.crop_height = bucket_resolution['height']
|
||||
self.crop_x = int(crop_left * max_scale_factor)
|
||||
self.crop_y = int(crop_top * max_scale_factor)
|
||||
self.crop_x = int(poi_x * max_scale_factor)
|
||||
self.crop_y = int(poi_y * max_scale_factor)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class ArgBreakMixin:
|
||||
|
||||
Reference in New Issue
Block a user