import base64
import hashlib
import json
import math
import os
import random
from collections import OrderedDict
from typing import TYPE_CHECKING, List, Dict, Union

import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm

from toolkit.basic import flush
from toolkit.buckets import get_bucket_for_image_size
from toolkit.metadata import get_meta_for_safetensors
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image
from PIL.ImageOps import exif_transpose
from toolkit.train_tools import get_torch_dtype

if TYPE_CHECKING:
    from toolkit.data_loader import AiToolkitDataset
    from toolkit.data_transfer_object.data_loader import FileItemDTO


# Optional augmentations, selectable per file item via its `augments` list.
transforms_dict = {
    'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
    'RandomEqualize': transforms.RandomEqualize(p=0.2),
}

# Caption sidecar extensions searched next to each image, in priority order.
caption_ext_list = ['txt', 'json', 'caption']


class CaptionMixin:
    """Mixin for datasets that look up a caption for a file-list entry."""

    def get_caption_item(self: 'AiToolkitDataset', index):
        """Return the caption for ``self.file_list[index]``.

        Looks for a sidecar file (``.txt``/``.json``/``.caption``) next to the
        image; a ``.json`` sidecar may carry the text under a ``caption`` key.
        Newlines are normalized to comma separators and empty tokens dropped.
        Falls back to ``default_prompt`` / ``default_caption`` attributes (if
        present on the dataset) when no sidecar file exists.
        """
        if not hasattr(self, 'caption_type'):
            raise Exception('caption_type not found on class instance')
        if not hasattr(self, 'file_list'):
            raise Exception('file_list not found on class instance')
        img_path_or_tuple = self.file_list[index]

        # Entries may be a tuple of items or a single item; each item is either
        # a plain path string or an object with a `.path` attribute.
        if isinstance(img_path_or_tuple, tuple):
            first = img_path_or_tuple[0]
            img_path = first if isinstance(first, str) else first.path
        else:
            img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path

        # Both branches of the original did this identical search; do it once.
        path_no_ext = os.path.splitext(img_path)[0]
        prompt_path = None
        for ext in caption_ext_list:
            candidate = path_no_ext + '.' + ext
            if os.path.exists(candidate):
                prompt_path = candidate
                break

        if prompt_path is not None:
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt = f.read()
            if prompt_path.endswith('.json'):
                prompt = json.loads(prompt)
                if 'caption' in prompt:
                    prompt = prompt['caption']
            # Normalize newlines (all OS flavors) into comma separators.
            prompt = prompt.replace('\n', ', ').replace('\r', ', ')
            # Strip whitespace and drop empty tokens, then rejoin.
            prompt_split = [p.strip() for p in prompt.split(',') if p.strip()]
            prompt = ', '.join(prompt_split)
        else:
            prompt = ''
            # get default_prompt if it exists on the class instance
            if hasattr(self, 'default_prompt'):
                prompt = self.default_prompt
            if hasattr(self, 'default_caption'):
                prompt = self.default_caption
        return prompt


if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig
    from toolkit.data_transfer_object.data_loader import FileItemDTO


class Bucket:
    """A single aspect-ratio bucket: a target size plus member file indices."""

    def __init__(self, width: int, height: int):
        self.width = width
        self.height = height
        # indices into the dataset's file_list that landed in this bucket
        self.file_list_idx: List[int] = []


class BucketsMixin:
    """Mixin that groups dataset files into resolution buckets and batches."""

    def __init__(self):
        self.buckets: Dict[str, Bucket] = {}
        self.batch_indices: List[List[int]] = []

    def build_batch_indices(self: 'AiToolkitDataset'):
        """Rebuild ``self.batch_indices``: one list of file indices per batch,
        never mixing files from different buckets in a single batch."""
        self.batch_indices = []
        for bucket in self.buckets.values():
            for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
                end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
                self.batch_indices.append(bucket.file_list_idx[start_idx:end_idx])

    def setup_buckets(self: 'AiToolkitDataset', quiet=False):
        """Assign every file item to a bucket and compute its scale/crop.

        For each file: pick the nearest bucket resolution, scale the image so
        both dimensions cover the bucket, then crop (randomly or centered) to
        the bucket size. Items with a point of interest delegate to
        ``setup_poi_bucket``. Also rebuilds the batch index lists.
        """
        if not hasattr(self, 'file_list'):
            raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
        if not hasattr(self, 'dataset_config'):
            raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}')
        if self.epoch_num > 0 and self.dataset_config.poi is None:
            # no need to rebuild buckets for now
            # todo handle random cropping for buckets
            return
        self.buckets = {}  # clear it

        config: 'DatasetConfig' = self.dataset_config
        resolution = config.resolution
        bucket_tolerance = config.bucket_tolerance
        file_list: List['FileItemDTO'] = self.file_list

        for idx, file_item in enumerate(file_list):
            file_item: 'FileItemDTO' = file_item
            width = int(file_item.width * file_item.dataset_config.scale)
            height = int(file_item.height * file_item.dataset_config.scale)

            if file_item.has_point_of_interest:
                # let the poi module handle the bucketing
                file_item.setup_poi_bucket()
            else:
                bucket_resolution = get_bucket_for_image_size(
                    width, height,
                    resolution=resolution,
                    divisibility=bucket_tolerance
                )

                # Scale by the larger factor so BOTH dimensions end up at or
                # above the bucket resolution before cropping.
                width_scale_factor = bucket_resolution["width"] / width
                height_scale_factor = bucket_resolution["height"] / height
                max_scale_factor = max(width_scale_factor, height_scale_factor)

                file_item.scale_to_width = int(width * max_scale_factor)
                file_item.scale_to_height = int(height * max_scale_factor)
                file_item.crop_height = bucket_resolution["height"]
                file_item.crop_width = bucket_resolution["width"]
                new_width = bucket_resolution["width"]
                new_height = bucket_resolution["height"]

                if self.dataset_config.random_crop:
                    # random crop within the scaled image
                    crop_x = random.randint(0, file_item.scale_to_width - new_width)
                    crop_y = random.randint(0, file_item.scale_to_height - new_height)
                    file_item.crop_x = crop_x
                    file_item.crop_y = crop_y
                else:
                    # central crop
                    file_item.crop_x = int((file_item.scale_to_width - new_width) / 2)
                    file_item.crop_y = int((file_item.scale_to_height - new_height) / 2)

                if file_item.crop_y < 0 or file_item.crop_x < 0:
                    # was a bare `print('debug')` leftover; surface the issue instead
                    print(f'Warning: negative crop offsets for {file_item.path}: '
                          f'crop_x={file_item.crop_x}, crop_y={file_item.crop_y}')

            # check if bucket exists, if not, create it
            bucket_key = f'{file_item.crop_width}x{file_item.crop_height}'
            if bucket_key not in self.buckets:
                self.buckets[bucket_key] = Bucket(file_item.crop_width, file_item.crop_height)
            self.buckets[bucket_key].file_list_idx.append(idx)

        self.build_batch_indices()
        if not quiet:
            # print the buckets
            print(f'Bucket sizes for {self.dataset_path}:')
            for key, bucket in self.buckets.items():
                print(f'{key}: {len(bucket.file_list_idx)} files')
            print(f'{len(self.buckets)} buckets made')


class CaptionProcessingDTOMixin:
    """FileItemDTO mixin: loading and runtime processing of captions."""

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)

    # todo allow for loading from sd-scripts style dict
    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
        """Populate ``self.raw_caption`` from (in order of precedence) an
        already-loaded value, a supplied caption dict keyed by path, or the
        sidecar file named by ``dataset_config.caption_ext``; otherwise the
        dataset's ``default_caption`` (or ``''``)."""
        if self.raw_caption is not None:
            # we already loaded it
            pass
        elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]:
            self.raw_caption = caption_dict[self.path]["caption"]
        else:
            # see if prompt file exists
            path_no_ext = os.path.splitext(self.path)[0]
            prompt_ext = self.dataset_config.caption_ext
            prompt_path = f"{path_no_ext}.{prompt_ext}"

            if os.path.exists(prompt_path):
                with open(prompt_path, 'r', encoding='utf-8') as f:
                    prompt = f.read()
                if prompt_path.endswith('.json'):
                    prompt = json.loads(prompt)
                    if 'caption' in prompt:
                        prompt = prompt['caption']
                # normalize newlines (all OS flavors) into comma separators
                prompt = prompt.replace('\n', ', ').replace('\r', ', ')
                prompt_split = [p.strip() for p in prompt.split(',') if p.strip()]
                prompt = ', '.join(prompt_split)
            else:
                prompt = ''
                if self.dataset_config.default_caption is not None:
                    prompt = self.dataset_config.default_caption
            self.raw_caption = prompt

    def get_caption(
            self: 'FileItemDTO',
            trigger=None,
            to_replace_list=None,
            add_if_not_present=False
    ):
        """Return the processed caption for training.

        Applies caption dropout (whole caption → ``''``), optional token
        shuffling, per-token dropout, and finally trigger-word injection via
        ``inject_trigger_into_prompt``.
        """
        raw_caption = self.raw_caption
        if raw_caption is None:
            raw_caption = ''

        # handle whole-caption dropout
        if self.dataset_config.caption_dropout_rate > 0:
            if random.random() < self.dataset_config.caption_dropout_rate:
                # drop the caption
                return ''

        # split into comma tokens, trimming whitespace and empties
        token_list = [x.strip() for x in raw_caption.split(',')]
        token_list = [x for x in token_list if x]

        if self.dataset_config.shuffle_tokens:
            random.shuffle(token_list)

        # handle per-token dropout
        if self.dataset_config.token_dropout_rate > 0:
            token_list = [
                token for token in token_list
                if random.random() > self.dataset_config.token_dropout_rate
            ]

        caption = ', '.join(token_list)
        caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
        return caption


class ImageProcessingDTOMixin:
    """FileItemDTO mixin: loading, flipping, scaling and cropping the image."""

    def load_and_process_image(
            self: 'FileItemDTO',
            transform: Union[None, transforms.Compose]
    ):
        """Load the image for this item into ``self.tensor``.

        If latents are cached, only the latent (and optional control image)
        is loaded. Otherwise the image is flipped as configured, then either
        scaled/cropped to its bucket or scaled/cropped to a square of
        ``dataset_config.resolution``, augmented, and transformed.
        """
        # if we are caching latents, just do that
        if self.is_latent_cached:
            self.get_latent()
            if self.has_control_image:
                self.load_control_image()
            return
        try:
            img = Image.open(self.path).convert('RGB')
            img = exif_transpose(img)
        except Exception as e:
            print(f"Error: {e}")
            print(f"Error loading image: {self.path}")
            # fix: original fell through and then crashed with NameError on
            # the undefined `img`; re-raise so the real cause is reported
            raise

        w, h = img.size
        if w > h and self.scale_to_width < self.scale_to_height:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
        elif h > w and self.scale_to_height < self.scale_to_width:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

        if self.flip_x:
            # fix: PIL's transpose returns a new image; the original discarded it
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if self.flip_y:
            # fix: same discarded-return bug as flip_x
            img = img.transpose(Image.FLIP_TOP_BOTTOM)

        if self.dataset_config.buckets:
            # scale and crop based on file item
            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
            # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
            if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
                # todo look into this. This still happens sometimes
                print('size mismatch')
            img = img.crop((
                self.crop_x,
                self.crop_y,
                self.crop_x + self.crop_width,
                self.crop_y + self.crop_height
            ))
        else:
            # Downscale the source image first
            # TODO this is not right
            img = img.resize(
                (int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)),
                Image.BICUBIC)
            min_img_size = min(img.size)
            if self.dataset_config.random_crop:
                if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution:
                    # (the original also had an inner `min_img_size < resolution`
                    # branch here, unreachable because the outer condition
                    # already guarantees the opposite — removed)
                    scale_size = random.randint(self.dataset_config.resolution, int(min_img_size))
                    scaler = scale_size / min_img_size
                    scale_width = int((img.width + 5) * scaler)
                    scale_height = int((img.height + 5) * scaler)
                    img = img.resize((scale_width, scale_height), Image.BICUBIC)
                img = transforms.RandomCrop(self.dataset_config.resolution)(img)
            else:
                img = transforms.CenterCrop(min_img_size)(img)
                img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)

        if self.augments is not None and len(self.augments) > 0:
            # do augmentations
            for augment in self.augments:
                if augment in transforms_dict:
                    img = transforms_dict[augment](img)

        if transform:
            img = transform(img)

        self.tensor = img
        if self.has_control_image:
            self.load_control_image()


class ControlFileItemDTOMixin:
    """FileItemDTO mixin: optional paired control image (e.g. for ControlNet).

    The control image is found by matching the image's basename inside
    ``dataset_config.control_path`` and is geometry-matched (flip, scale,
    crop) to the main image.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_control_image = False
        self.control_path: Union[str, None] = None
        self.control_tensor: Union[torch.Tensor, None] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        if dataset_config.control_path is not None:
            # we are using control images; find one with a matching basename
            control_path = dataset_config.control_path
            img_path = kwargs.get('path', None)
            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                candidate = os.path.join(control_path, file_name_no_ext + ext)
                if os.path.exists(candidate):
                    self.control_path = candidate
                    self.has_control_image = True
                    break

    def load_control_image(self: 'FileItemDTO'):
        """Load the control image into ``self.control_tensor``, applying the
        same flip/scale/crop geometry as the main image (buckets only)."""
        try:
            img = Image.open(self.control_path).convert('RGB')
            img = exif_transpose(img)
        except Exception as e:
            print(f"Error: {e}")
            print(f"Error loading image: {self.control_path}")
            # fix: original fell through to a NameError on undefined `img`
            raise

        w, h = img.size
        if w > h and self.scale_to_width < self.scale_to_height:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
        elif h > w and self.scale_to_height < self.scale_to_width:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

        if self.flip_x:
            # fix: transpose returns a new image; the original discarded it,
            # leaving the control image unflipped relative to the main image
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        if self.flip_y:
            img = img.transpose(Image.FLIP_TOP_BOTTOM)

        if self.dataset_config.buckets:
            # scale and crop based on file item
            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
            img = img.crop((
                self.crop_x,
                self.crop_y,
                self.crop_x + self.crop_width,
                self.crop_y + self.crop_height
            ))
        else:
            raise Exception("Control images not supported for non-bucket datasets")

        self.control_tensor = transforms.ToTensor()(img)

    def cleanup_control(self: 'FileItemDTO'):
        # release the tensor; it can be reloaded from disk when needed
        self.control_tensor = None


class PoiFileItemDTOMixin:
    # Point of interest bounding box. Allows for dynamic cropping without
    # cropping out the main subject. Items in the poi will always be inside
    # the image when random cropping.
    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        # poi is the name of a bounding box in the caption json file
        dataset_config = kwargs.get('dataset_config', None)
        path = kwargs.get('path', None)
        self.poi: Union[str, None] = dataset_config.poi
        self.has_point_of_interest = self.poi is not None
        self.poi_x: Union[int, None] = None
        self.poi_y: Union[int, None] = None
        self.poi_width: Union[int, None] = None
        self.poi_height: Union[int, None] = None
        if self.poi is not None:
            # poi cropping is dynamic per epoch, so cached latents would go stale
            if dataset_config.cache_latents or dataset_config.cache_latents_to_disk:
                raise Exception(
                    f"Error: poi is not supported when caching latents. Please set cache_latents and cache_latents_to_disk to False in the dataset config"
                )
            # the poi boxes live in the json caption file
            if dataset_config.caption_ext != 'json':
                raise Exception(
                    f"Error: poi is only supported when using json captions. Please set caption_ext to json in the dataset config"
                )
            self.poi = self.poi.strip()
            # get the caption path
            file_path_no_ext = os.path.splitext(path)[0]
            caption_path = file_path_no_ext + '.json'
            if not os.path.exists(caption_path):
                raise Exception(f"Error: caption file not found for poi: {caption_path}")
            with open(caption_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)
            if 'poi' not in json_data:
                raise Exception(f"Error: poi not found in caption file: {caption_path}")
            if self.poi not in json_data['poi']:
                raise Exception(f"Error: poi not found in caption file: {caption_path}")
            # poi has x, y, width, height
            poi = json_data['poi'][self.poi]
            self.poi_x = int(poi['x'])
            self.poi_y = int(poi['y'])
            self.poi_width = int(poi['width'])
            self.poi_height = int(poi['height'])

    def setup_poi_bucket(self: 'FileItemDTO'):
        """Pick a random crop that fully contains the poi box, then fit that
        crop to a bucket resolution, setting this item's scale/crop fields."""
        resolution = self.dataset_config.resolution
        bucket_tolerance = self.dataset_config.bucket_tolerance
        initial_width = int(self.width * self.dataset_config.scale)
        initial_height = int(self.height * self.dataset_config.scale)
        poi_x = int(self.poi_x * self.dataset_config.scale)
        poi_y = int(self.poi_y * self.dataset_config.scale)
        poi_width = int(self.poi_width * self.dataset_config.scale)
        poi_height = int(self.poi_height * self.dataset_config.scale)

        # todo handle a poi that is smaller than resolution
        # random crop edges, constrained so the poi box stays inside
        crop_left = random.randint(0, poi_x)
        crop_right = random.randint(poi_x + poi_width, initial_width)
        crop_top = random.randint(0, poi_y)
        crop_bottom = random.randint(poi_y + poi_height, initial_height)

        new_width = crop_right - crop_left
        new_height = crop_bottom - crop_top

        bucket_resolution = get_bucket_for_image_size(
            new_width, new_height,
            resolution=resolution,
            divisibility=bucket_tolerance
        )

        # Use the maximum of the scale factors to ensure both dimensions are
        # scaled above the bucket resolution.
        width_scale_factor = bucket_resolution["width"] / new_width
        height_scale_factor = bucket_resolution["height"] / new_height
        max_scale_factor = max(width_scale_factor, height_scale_factor)

        self.scale_to_width = int(initial_width * max_scale_factor)
        self.scale_to_height = int(initial_height * max_scale_factor)
        self.crop_width = bucket_resolution['width']
        self.crop_height = bucket_resolution['height']
        self.crop_x = int(crop_left * max_scale_factor)
        self.crop_y = int(crop_top * max_scale_factor)


class ArgBreakMixin:
    # just stops super calls from hitting object
    def __init__(self, *args, **kwargs):
        pass


class LatentCachingFileItemDTOMixin:
    """FileItemDTO mixin: per-file latent caching (disk and/or memory)."""

    def __init__(self, *args, **kwargs):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self._encoded_latent: Union[torch.Tensor, None] = None
        self._latent_path: Union[str, None] = None
        self.is_latent_cached = False
        self.is_caching_to_disk = False
        self.is_caching_to_memory = False
        self.latent_load_device = 'cpu'
        # sd1 or sdxl or others
        self.latent_space_version = 'sd1'
        # todo, increment this if we change the latent format to invalidate cache
        self.latent_version = 1

    def get_latent_info_dict(self: 'FileItemDTO'):
        """Return the dict that identifies this latent; it is hashed into the
        cache filename, so any change invalidates previously cached files."""
        item = OrderedDict([
            ("filename", os.path.basename(self.path)),
            ("scale_to_width", self.scale_to_width),
            ("scale_to_height", self.scale_to_height),
            ("crop_x", self.crop_x),
            ("crop_y", self.crop_y),
            ("crop_width", self.crop_width),
            ("crop_height", self.crop_height),
            ("latent_space_version", self.latent_space_version),
            ("latent_version", self.latent_version),
        ])
        # when adding items, do it after so we dont change old latents
        if self.flip_x:
            item["flip_x"] = True
        if self.flip_y:
            item["flip_y"] = True
        return item

    def get_latent_path(self: 'FileItemDTO', recalculate=False):
        """Return (computing and memoizing on first use) the safetensors path
        for this item's cached latent, inside a sibling ``_latent_cache`` dir.
        The filename embeds a url-safe md5 of ``get_latent_info_dict()``."""
        if self._latent_path is not None and not recalculate:
            return self._latent_path
        # we store latents in a folder in same path as image called _latent_cache
        img_dir = os.path.dirname(self.path)
        latent_dir = os.path.join(img_dir, '_latent_cache')
        hash_dict = self.get_latent_info_dict()
        filename_no_ext = os.path.splitext(os.path.basename(self.path))[0]
        # get base64 hash of md5 checksum of hash_dict
        hash_input = json.dumps(hash_dict, sort_keys=True).encode('utf-8')
        hash_str = base64.urlsafe_b64encode(hashlib.md5(hash_input).digest()).decode('ascii')
        hash_str = hash_str.replace('=', '')
        self._latent_path = os.path.join(latent_dir, f'{filename_no_ext}_{hash_str}.safetensors')
        return self._latent_path

    def cleanup_latent(self):
        """Release GPU memory for the latent after use: drop it entirely when
        disk-cached, otherwise park it on the cpu."""
        if self._encoded_latent is not None:
            if not self.is_caching_to_memory:
                # we are caching on disk, don't save in memory
                self._encoded_latent = None
            else:
                # move it back to cpu
                self._encoded_latent = self._encoded_latent.to('cpu')

    def get_latent(self, device=None):
        """Return the cached latent tensor, lazily loading it from disk; None
        when latent caching is disabled for this item."""
        if not self.is_latent_cached:
            return None
        if self._encoded_latent is None:
            # load it from disk
            state_dict = load_file(
                self.get_latent_path(),
                device=device if device is not None else self.latent_load_device
            )
            self._encoded_latent = state_dict['latent']
        return self._encoded_latent


class LatentCachingMixin:
    """Dataset mixin that pre-encodes every file's latent via the VAE."""

    def __init__(self: 'AiToolkitDataset', **kwargs):
        # if we have super, call it
        if hasattr(super(), '__init__'):
            super().__init__(**kwargs)
        self.latent_cache = {}

    def cache_latents_all_latents(self: 'AiToolkitDataset'):
        """Encode (or load) the latent for every file item.

        Existing disk caches are reused; missing ones are computed by loading
        and VAE-encoding the image, then saved to disk and/or kept in memory
        per the dataset's caching flags. Moves non-VAE model parts to cpu for
        the duration and restores device state afterwards.
        """
        print(f"Caching latents for {self.dataset_path}")
        to_disk = self.is_caching_latents_to_disk
        to_memory = self.is_caching_latents_to_memory
        if to_disk:
            print(" - Saving latents to disk")
        if to_memory:
            print(" - Keeping latents in memory")

        # move sd items to cpu except for vae
        self.sd.set_device_state_preset('cache_latents')

        for i, file_item in tqdm(enumerate(self.file_list), desc=f'Caching latents{" to disk" if to_disk else ""}'):
            # set latent space version
            if self.sd.is_xl:
                file_item.latent_space_version = 'sdxl'
            else:
                file_item.latent_space_version = 'sd1'
            file_item.is_caching_to_disk = to_disk
            file_item.is_caching_to_memory = to_memory
            file_item.latent_load_device = self.sd.device

            latent_path = file_item.get_latent_path(recalculate=True)
            if os.path.exists(latent_path):
                # already saved to disk
                if to_memory:
                    # load it into memory
                    state_dict = load_file(latent_path, device='cpu')
                    file_item._encoded_latent = state_dict['latent'].to('cpu', dtype=self.sd.torch_dtype)
            else:
                # not saved to disk, calculate; load the image first
                file_item.load_and_process_image(self.transform)
                dtype = self.sd.torch_dtype
                device = self.sd.device_torch
                # add batch dimension
                imgs = file_item.tensor.unsqueeze(0).to(device, dtype=dtype)
                latent = self.sd.encode_images(imgs).squeeze(0)
                if to_disk:
                    state_dict = OrderedDict([
                        ('latent', latent.clone().detach().cpu()),
                    ])
                    meta = get_meta_for_safetensors(file_item.get_latent_info_dict())
                    os.makedirs(os.path.dirname(latent_path), exist_ok=True)
                    save_file(state_dict, latent_path, metadata=meta)
                if to_memory:
                    # keep it in memory
                    file_item._encoded_latent = latent.to('cpu', dtype=self.sd.torch_dtype)
                del imgs
                del latent
                del file_item.tensor
                flush(garbage_collect=False)
            file_item.is_latent_cached = True
            # flush every 100
            # if i % 100 == 0:
            #     flush()

        # restore device state
        self.sd.restore_device_state()