import math
import os
import random
from typing import TYPE_CHECKING, List, Dict, Union

from toolkit.buckets import get_bucket_for_image_size
from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image
from PIL.ImageOps import exif_transpose

if TYPE_CHECKING:
    from toolkit.data_loader import AiToolkitDataset
    from toolkit.data_transfer_object.data_loader import FileItemDTO


# def get_associated_caption_from_img_path(img_path):

class CaptionMixin:
    def get_caption_item(self: 'AiToolkitDataset', index):
        if not hasattr(self, 'caption_type'):
            raise Exception('caption_type not found on class instance')
        if not hasattr(self, 'file_list'):
            raise Exception('file_list not found on class instance')
        img_path_or_tuple = self.file_list[index]
        if isinstance(img_path_or_tuple, tuple):
            img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path
            # check if either image has a prompt file
            path_no_ext = os.path.splitext(img_path)[0]
            prompt_path = path_no_ext + '.txt'
            if not os.path.exists(prompt_path):
                img_path = img_path_or_tuple[1] if isinstance(img_path_or_tuple[1], str) else img_path_or_tuple[1].path
                path_no_ext = os.path.splitext(img_path)[0]
                prompt_path = path_no_ext + '.txt'
        else:
            img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path
            # see if a prompt file exists
            path_no_ext = os.path.splitext(img_path)[0]
            prompt_path = path_no_ext + '.txt'

        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt = f.read()
                # replace newlines with comma separators
                prompt = prompt.replace('\n', ', ')
                # also handle carriage returns for Windows-style line endings
                prompt = prompt.replace('\r', ', ')
                prompt_split = prompt.split(',')
                # remove empty strings
                prompt_split = [p.strip() for p in prompt_split if p.strip()]
                # join back together
                prompt = ', '.join(prompt_split)
        else:
            prompt = ''
            # fall back to default_prompt / default_caption if set on the class instance
            if hasattr(self, 'default_prompt'):
                prompt = self.default_prompt
            if hasattr(self, 'default_caption'):
                prompt = self.default_caption
        return prompt

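# Illustrative sketch (not part of the toolkit): CaptionMixin expects the host
# dataset class to define `caption_type` and `file_list`, and optionally
# `default_prompt` / `default_caption`. For an image `dataset/cat.jpg`, the
# caption is read from a sidecar `dataset/cat.txt` if present, e.g.:
#
#   class MyDataset(CaptionMixin):
#       def __init__(self, file_list):
#           self.caption_type = 'txt'
#           self.file_list = file_list
#           self.default_caption = 'a photo'
#
#   ds = MyDataset(['dataset/cat.jpg'])
#   ds.get_caption_item(0)  # contents of dataset/cat.txt, or 'a photo' if no .txt exists
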
if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig
    from toolkit.data_transfer_object.data_loader import FileItemDTO

class Bucket:
    def __init__(self, width: int, height: int):
        self.width = width
        self.height = height
        self.file_list_idx: List[int] = []

class BucketsMixin:
    def __init__(self):
        self.buckets: Dict[str, Bucket] = {}
        self.batch_indices: List[List[int]] = []

    def build_batch_indices(self: 'AiToolkitDataset'):
        for key, bucket in self.buckets.items():
            for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
                end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
                batch = bucket.file_list_idx[start_idx:end_idx]
                self.batch_indices.append(batch)

    def setup_buckets(self: 'AiToolkitDataset'):
        if not hasattr(self, 'file_list'):
            raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
        if not hasattr(self, 'dataset_config'):
            raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}')

        config: 'DatasetConfig' = self.dataset_config
        resolution = config.resolution
        bucket_tolerance = config.bucket_tolerance
        file_list: List['FileItemDTO'] = self.file_list

        total_pixels = resolution * resolution

        # for file_item in enumerate(file_list):
        for idx, file_item in enumerate(file_list):
            width = file_item.crop_width
            height = file_item.crop_height

            bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution)

            # scale the smaller side to the bucket size and keep the aspect ratio
            if width > height:
                file_item.scale_height = bucket_resolution["height"]
                file_item.scale_width = int(width * (bucket_resolution["height"] / height))
            else:
                file_item.scale_width = bucket_resolution["width"]
                file_item.scale_height = int(height * (bucket_resolution["width"] / width))

            file_item.crop_height = bucket_resolution["height"]
            file_item.crop_width = bucket_resolution["width"]

            new_width = bucket_resolution["width"]
            new_height = bucket_resolution["height"]

            # check if bucket exists, if not, create it
            bucket_key = f'{new_width}x{new_height}'
            if bucket_key not in self.buckets:
                self.buckets[bucket_key] = Bucket(new_width, new_height)
            self.buckets[bucket_key].file_list_idx.append(idx)

        self.build_batch_indices()

        # print the buckets
        name = f"{os.path.basename(self.dataset_path)} ({self.resolution})"
        print(f'Bucket sizes for {self.dataset_path}:')
        for key, bucket in self.buckets.items():
            print(f'{key}: {len(bucket.file_list_idx)} files')
        print(f'{len(self.buckets)} buckets made')

        # file buckets made

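# Illustrative sketch (not part of the toolkit): with resolution=512 and a
# 1920x1080 image, suppose get_bucket_for_image_size returns a 640x384 bucket
# (the exact bucket grid lives in toolkit.buckets). setup_buckets would record
# scale_height=384 and scale_width=int(1920 * 384 / 1080)=682 on the file item,
# set its crop size to 640x384, and append its index to
# self.buckets['640x384'].file_list_idx. build_batch_indices then slices each
# bucket's index list into batches of self.batch_size, e.g. a bucket holding
# 10 indices with batch_size=4 yields batches of 4, 4 and 2.
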
class CaptionProcessingDTOMixin:

    # todo allow for loading from sd-scripts style dict
    def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
        if self.raw_caption is not None:
            # we already loaded it
            pass
        elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]:
            self.raw_caption = caption_dict[self.path]["caption"]
        else:
            # see if a prompt file exists
            path_no_ext = os.path.splitext(self.path)[0]
            prompt_ext = self.dataset_config.caption_ext
            prompt_path = f"{path_no_ext}.{prompt_ext}"

            if os.path.exists(prompt_path):
                with open(prompt_path, 'r', encoding='utf-8') as f:
                    prompt = f.read()
                    # replace newlines with comma separators
                    prompt = prompt.replace('\n', ', ')
                    # also handle carriage returns for Windows-style line endings
                    prompt = prompt.replace('\r', ', ')
                    prompt_split = prompt.split(',')
                    # remove empty strings
                    prompt_split = [p.strip() for p in prompt_split if p.strip()]
                    # join back together
                    prompt = ', '.join(prompt_split)
            else:
                prompt = ''
                if self.dataset_config.default_caption is not None:
                    prompt = self.dataset_config.default_caption
            self.raw_caption = prompt
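
    # Illustrative sketch (not part of the toolkit): `caption_dict`, when given,
    # is assumed to be keyed by the image path, e.g.
    #
    #   caption_dict = {
    #       '/data/imgs/cat.jpg': {'caption': 'a photo of a cat, sitting'},
    #   }
    #
    # If the path is missing from the dict, the caption falls back to the
    # sidecar file `<image>.{caption_ext}` and then to `default_caption`.
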
    def get_caption(
            self: 'FileItemDTO',
            trigger=None,
            to_replace_list=None,
            add_if_not_present=False
    ):
        raw_caption = self.raw_caption
        if raw_caption is None:
            raw_caption = ''
        # handle caption dropout
        if self.dataset_config.caption_dropout_rate > 0:
            # get a random float from 0 to 1
            rand = random.random()
            if rand < self.dataset_config.caption_dropout_rate:
                # drop the caption
                return ''

        # get tokens
        token_list = raw_caption.split(',')
        # trim whitespace
        token_list = [x.strip() for x in token_list]
        # remove empty strings
        token_list = [x for x in token_list if x]

        if self.dataset_config.shuffle_tokens:
            random.shuffle(token_list)

        # handle token dropout
        if self.dataset_config.token_dropout_rate > 0:
            new_token_list = []
            for token in token_list:
                # get a random float from 0 to 1
                rand = random.random()
                if rand > self.dataset_config.token_dropout_rate:
                    # keep the token
                    new_token_list.append(token)
            token_list = new_token_list

        # join back together
        caption = ', '.join(token_list)
        caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
        return caption

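# Illustrative sketch (not part of the toolkit): get_caption() may return an
# empty string with probability `caption_dropout_rate`; otherwise it splits the
# raw caption on commas, optionally shuffles the tokens (`shuffle_tokens`) and
# drops individual tokens (`token_dropout_rate`), rejoins them, and passes the
# result through toolkit.prompt_utils.inject_trigger_into_prompt, e.g.
#
#   file_item.raw_caption = 'a cat, sitting, on a rug'
#   file_item.get_caption()  # e.g. 'a cat, sitting, on a rug' (order/tokens may vary)
#
# How and where a `trigger` word is placed is determined by
# inject_trigger_into_prompt and is not guaranteed here.
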
class ImageProcessingDTOMixin:
    def load_and_process_image(
            self: 'FileItemDTO',
            transform: Union[None, transforms.Compose]
    ):
        # todo make sure this matches
        try:
            img = Image.open(self.path).convert('RGB')
            img = exif_transpose(img)
        except Exception as e:
            print(f"Error: {e}")
            print(f"Error loading image: {self.path}")
            # re-raise so the caller sees the original failure instead of a NameError below
            raise e
        w, h = img.size
        if w > h and self.scale_to_width < self.scale_to_height:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
        elif h > w and self.scale_to_height < self.scale_to_width:
            # throw error, they should match
            raise ValueError(
                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

        if self.dataset_config.buckets:
            # todo allow scaling and cropping, will be hard to add
            # scale and crop based on the precomputed values on the file item
            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
            img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
        else:
            # downscale the source image first
            img = img.resize(
                (int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)),
                Image.BICUBIC)
            min_img_size = min(img.size)
            if self.dataset_config.random_crop:
                if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution:
                    if min_img_size < self.dataset_config.resolution:
                        print(
                            f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.dataset_config.resolution}, image file={self.path}")
                        scale_size = self.dataset_config.resolution
                    else:
                        scale_size = random.randint(self.dataset_config.resolution, int(min_img_size))
                    img = img.resize((scale_size, scale_size), Image.BICUBIC)
                img = transforms.RandomCrop(self.dataset_config.resolution)(img)
            else:
                img = transforms.CenterCrop(min_img_size)(img)
                img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC)

        if transform:
            img = transform(img)

        self.tensor = img
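

# Illustrative sketch (not part of the toolkit): a FileItemDTO that mixes in
# ImageProcessingDTOMixin is typically processed with a tensor transform, e.g.
#
#   from torchvision import transforms
#
#   to_tensor = transforms.Compose([
#       transforms.ToTensor(),
#       transforms.Normalize([0.5], [0.5]),
#   ])
#   file_item.load_and_process_image(to_tensor)
#   file_item.tensor  # CHW float tensor, scaled/cropped to the bucket or resolution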