import os
from typing import TYPE_CHECKING, List, Dict


def _caption_path_for(file_entry) -> str:
    """Return the ``.txt`` caption path that sits next to ``file_entry``.

    ``file_entry`` is either a path string or an object exposing a ``path``
    attribute; the image extension is swapped for ``.txt``.
    """
    img_path = file_entry if isinstance(file_entry, str) else file_entry.path
    return os.path.splitext(img_path)[0] + '.txt'


class CaptionMixin:
    """Mixin that loads a text caption stored beside an image file.

    The host class must provide ``caption_type`` and ``file_list``. It may
    optionally provide ``default_prompt`` and/or ``default_caption``, which
    are used when no caption file exists.
    """

    def get_caption_item(self, index):
        """Return the caption for ``self.file_list[index]``.

        The entry may be a path string, an object with a ``path`` attribute,
        or a tuple of two such values. For a tuple, the caption next to the
        first element is preferred and the second element is used only when
        the first has no caption file.

        The caption text is normalised into a single ``', '``-separated line:
        newlines become separators and empty/blank segments are dropped.

        Returns:
            The normalised caption. When no caption file exists, returns
            ``default_caption`` if set, else ``default_prompt`` if set,
            else ``''``.

        Raises:
            Exception: if ``caption_type`` or ``file_list`` is missing on
                the instance.
        """
        if not hasattr(self, 'caption_type'):
            raise Exception('caption_type not found on class instance')
        if not hasattr(self, 'file_list'):
            raise Exception('file_list not found on class instance')

        entry = self.file_list[index]
        if isinstance(entry, tuple):
            # check if either has a prompt file, preferring the first
            prompt_path = _caption_path_for(entry[0])
            if not os.path.exists(prompt_path):
                prompt_path = _caption_path_for(entry[1])
        else:
            prompt_path = _caption_path_for(entry)

        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt = f.read()
            # treat line breaks (any OS flavour) as tag separators
            prompt = prompt.replace('\n', ', ').replace('\r', ', ')
            # drop empty segments and rejoin with a uniform separator
            segments = [p.strip() for p in prompt.split(',') if p.strip()]
            prompt = ', '.join(segments)
        else:
            prompt = ''
            # fall back to defaults if present; default_caption takes
            # precedence over default_prompt because it is applied last
            if hasattr(self, 'default_prompt'):
                prompt = self.default_prompt
            if hasattr(self, 'default_caption'):
                prompt = self.default_caption
        return prompt


if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig
    from toolkit.data_loader import FileItem


class Bucket:
    """A group of file indices that share the same target width and height."""

    def __init__(self, width: int, height: int):
        self.width = width
        self.height = height
        # indices into the owning dataset's file_list
        self.file_list_idx: List[int] = []


class BucketsMixin:
    """Mixin that groups files into same-size buckets and batches them.

    The host class must provide ``file_list``, ``dataset_config`` (with
    ``resolution`` and ``bucket_tolerance``) and ``batch_size``.
    """

    def __init__(self):
        # bucket key is '<width>x<height>'
        self.buckets: Dict[str, Bucket] = {}
        self.batch_indices: List[List[int]] = []

    def build_batch_indices(self) -> None:
        """Rebuild ``self.batch_indices`` from the current buckets.

        Each bucket is split into consecutive chunks of at most
        ``self.batch_size`` indices, so a batch never mixes buckets.
        The list is reset first so repeated calls do not accumulate
        duplicate batches.
        """
        self.batch_indices = []
        for bucket in self.buckets.values():
            indices = bucket.file_list_idx
            for start_idx in range(0, len(indices), self.batch_size):
                # slicing clamps at the end of the list, so the last batch
                # of a bucket may be shorter than batch_size
                self.batch_indices.append(indices[start_idx:start_idx + self.batch_size])

    def setup_buckets(self) -> None:
        """Assign every file to a size bucket and build the batch indices.

        For each file, the smaller of ``crop_width``/``crop_height`` is set
        to the configured resolution (rounded down to a multiple of
        ``bucket_tolerance``); the larger side is scaled by the same ratio
        and also rounded down to a multiple of ``bucket_tolerance``. The
        results are written back onto the file item's ``crop_*`` attributes,
        and the item's index is stored in the bucket keyed
        ``'<width>x<height>'``.

        Existing buckets are discarded first, so calling this more than once
        does not register the same file index twice, and the method works
        even if the host class never ran ``BucketsMixin.__init__``.

        Raises:
            Exception: if ``file_list`` or ``dataset_config`` is missing on
                the instance.
        """
        if not hasattr(self, 'file_list'):
            raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
        if not hasattr(self, 'dataset_config'):
            raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}')

        config: 'DatasetConfig' = self.dataset_config
        resolution = config.resolution
        bucket_tolerance = config.bucket_tolerance
        file_list: List['FileItem'] = self.file_list

        # start from a clean slate so re-running does not duplicate indices
        self.buckets = {}

        # make sure our resolution is divisible by bucket_tolerance
        if resolution % bucket_tolerance != 0:
            # reduce it to the nearest divisible number
            resolution = resolution - (resolution % bucket_tolerance)

        for idx, file_item in enumerate(file_list):
            width = file_item.crop_width
            height = file_item.crop_height

            # determine new size, smallest dimension should be equal to resolution
            # the other dimension should be the same ratio it is now (bigger)
            new_width = resolution
            new_height = resolution
            new_x = file_item.crop_x
            new_y = file_item.crop_y
            # NOTE(review): the offsets below are derived from the reduction in
            # the *scaled* size — confirm consumers expect that coordinate space.
            if width > height:
                # scale width to match new resolution
                new_width = int(width * (resolution / height))
                # make sure new_width is divisible by bucket_tolerance
                if new_width % bucket_tolerance != 0:
                    # reduce it to the nearest divisible number
                    reduction = new_width % bucket_tolerance
                    new_width = new_width - reduction
                    # adjust the new x position so we evenly crop
                    new_x = int(new_x + (reduction / 2))
            elif height > width:
                # scale height to match new resolution
                new_height = int(height * (resolution / width))
                # make sure new_height is divisible by bucket_tolerance
                if new_height % bucket_tolerance != 0:
                    # reduce it to the nearest divisible number
                    reduction = new_height % bucket_tolerance
                    new_height = new_height - reduction
                    # adjust the new y position so we evenly crop
                    new_y = int(new_y + (reduction / 2))

            # add info to file
            file_item.crop_x = new_x
            file_item.crop_y = new_y
            file_item.crop_width = new_width
            file_item.crop_height = new_height

            # check if bucket exists, if not, create it
            bucket_key = f'{new_width}x{new_height}'
            if bucket_key not in self.buckets:
                self.buckets[bucket_key] = Bucket(new_width, new_height)
            self.buckets[bucket_key].file_list_idx.append(idx)

        self.build_batch_indices()

        # print the buckets
        print(f'Bucket sizes for {self.__class__.__name__}:')
        for key, bucket in self.buckets.items():
            print(f'{key}: {len(bucket.file_list_idx)} files')
        print(f'{len(self.buckets)} buckets made')