Numerous fixes for time sampling. Still not perfect

This commit is contained in:
Jaret Burkett
2023-11-28 07:34:43 -07:00
parent d7e55b6ad4
commit 792a5e37e2
7 changed files with 160 additions and 91 deletions

View File

@@ -30,7 +30,7 @@ if TYPE_CHECKING:
# def get_associated_caption_from_img_path(img_path):
# https://demo.albumentations.ai/
class Augments:
def __init__(self, **kwargs):
self.method_name = kwargs.get('method', None)
@@ -167,10 +167,11 @@ class BucketsMixin:
width = int(file_item.width * file_item.dataset_config.scale)
height = int(file_item.height * file_item.dataset_config.scale)
did_process_poi = False
if file_item.has_point_of_interest:
# let the poi module handle the bucketing
file_item.setup_poi_bucket()
else:
# Attempt to process the poi if we can. It wont process if the image is smaller than the resolution
did_process_poi = file_item.setup_poi_bucket()
if not did_process_poi:
bucket_resolution = get_bucket_for_image_size(
width, height,
resolution=resolution,
@@ -323,7 +324,7 @@ class CaptionProcessingDTOMixin:
if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
# add random triggers
caption = random.choice(self.dataset_config.random_triggers) + ', ' + caption
caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
if self.dataset_config.shuffle_tokens:
# shuffle again
@@ -803,79 +804,68 @@ class PoiFileItemDTOMixin:
self.poi_y = self.height - self.poi_y - self.poi_height
def setup_poi_bucket(self: 'FileItemDTO'):
# we are using poi, so we need to calculate the bucket based on the poi
# TODO this will allow poi to be smaller than resolution. Could affect training image size
poi_resolution = min(
self.dataset_config.resolution,
get_resolution(
self.poi_width * self.dataset_config.scale,
self.poi_height * self.dataset_config.scale
)
)
resolution = min(self.dataset_config.resolution, poi_resolution)
bucket_tolerance = self.dataset_config.bucket_tolerance
initial_width = int(self.width * self.dataset_config.scale)
initial_height = int(self.height * self.dataset_config.scale)
# we are using poi, so we need to calculate the bucket based on the poi
# if img resolution is less than dataset resolution, just return and let the normal bucketing happen
img_resolution = get_resolution(initial_width, initial_height)
if img_resolution <= self.dataset_config.resolution:
return False # will trigger normal bucketing
bucket_tolerance = self.dataset_config.bucket_tolerance
poi_x = int(self.poi_x * self.dataset_config.scale)
poi_y = int(self.poi_y * self.dataset_config.scale)
poi_width = int(self.poi_width * self.dataset_config.scale)
poi_height = int(self.poi_height * self.dataset_config.scale)
# expand poi to fit resolution
if poi_width < resolution:
width_difference = resolution - poi_width
poi_x = poi_x - int(width_difference / 2)
poi_width = resolution
# make sure we dont go out of bounds
if poi_x < 0:
# loop to keep expanding until we are at the proper resolution. This is not ideal, we can probably handle it better
num_loops = 0
while True:
# crop left
if poi_x > 0:
poi_x = random.randint(0, poi_x)
else:
poi_x = 0
# if total width too much, crop
if poi_x + poi_width > initial_width:
poi_width = initial_width - poi_x
if poi_height < resolution:
height_difference = resolution - poi_height
poi_y = poi_y - int(height_difference / 2)
poi_height = resolution
# make sure we dont go out of bounds
if poi_y < 0:
# crop right
cr_min = poi_x + poi_width
if cr_min < initial_width:
crop_right = random.randint(poi_x + poi_width, initial_width)
else:
crop_right = initial_width
poi_width = crop_right - poi_x
if poi_y > 0:
poi_y = random.randint(0, poi_y)
else:
poi_y = 0
# if total height too much, crop
if poi_y + poi_height > initial_height:
poi_height = initial_height - poi_y
# crop left
if poi_x > 0:
crop_left = random.randint(0, poi_x)
else:
crop_left = 0
if poi_y + poi_height < initial_height:
crop_bottom = random.randint(poi_y + poi_height, initial_height)
else:
crop_bottom = initial_height
# crop right
cr_min = poi_x + poi_width
if cr_min < initial_width:
crop_right = random.randint(poi_x + poi_width, initial_width)
else:
crop_right = initial_width
poi_height = crop_bottom - poi_y
# now we have our random crop, but it may be smaller than resolution. Check and expand if needed
current_resolution = get_resolution(poi_width, poi_height)
if current_resolution >= self.dataset_config.resolution:
# We can break now
break
else:
num_loops += 1
if num_loops > 100:
print(
f"Warning: poi bucketing looped too many times. This should not happen. Please report this issue.")
return False
if poi_y > 0:
crop_top = random.randint(0, poi_y)
else:
crop_top = 0
if poi_y + poi_height < initial_height:
crop_bottom = random.randint(poi_y + poi_height, initial_height)
else:
crop_bottom = initial_height
new_width = crop_right - crop_left
new_height = crop_bottom - crop_top
new_width = poi_width
new_height = poi_height
bucket_resolution = get_bucket_for_image_size(
new_width, new_height,
resolution=resolution,
resolution=self.dataset_config.resolution,
divisibility=bucket_tolerance
)
@@ -888,8 +878,10 @@ class PoiFileItemDTOMixin:
self.scale_to_height = int(initial_height * max_scale_factor)
self.crop_width = bucket_resolution['width']
self.crop_height = bucket_resolution['height']
self.crop_x = int(crop_left * max_scale_factor)
self.crop_y = int(crop_top * max_scale_factor)
self.crop_x = int(poi_x * max_scale_factor)
self.crop_y = int(poi_y * max_scale_factor)
return True
class ArgBreakMixin: