mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-02 17:19:48 +00:00
Added bucketting capabilities to dataloader. Finally have full planned capability. noice
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torchvision import transforms
|
||||
@@ -11,15 +12,9 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
|
||||
from tqdm import tqdm
|
||||
import albumentations as A
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
from toolkit.dataloader_mixins import CaptionMixin
|
||||
|
||||
BUCKET_STEPS = 64
|
||||
|
||||
def get_bucket_sizes_for_resolution(resolution: int) -> List[int]:
|
||||
# make sure resolution is divisible by 8
|
||||
if resolution % 8 != 0:
|
||||
resolution = resolution - (resolution % 8)
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
|
||||
|
||||
|
||||
class ImageDataset(Dataset, CaptionMixin):
|
||||
@@ -291,32 +286,74 @@ class PairedImageDataset(Dataset):
|
||||
return img, prompt, (self.neg_weight, self.pos_weight)
|
||||
|
||||
|
||||
class AiToolkitDataset(Dataset, CaptionMixin):
|
||||
def __init__(self, dataset_config: 'DatasetConfig'):
|
||||
printed_messages = []
|
||||
|
||||
|
||||
def print_once(msg):
|
||||
global printed_messages
|
||||
if msg not in printed_messages:
|
||||
print(msg)
|
||||
printed_messages.append(msg)
|
||||
|
||||
|
||||
class FileItem:
|
||||
def __init__(self, **kwargs):
|
||||
self.path = kwargs.get('path', None)
|
||||
self.width = kwargs.get('width', None)
|
||||
self.height = kwargs.get('height', None)
|
||||
# we scale first, then crop
|
||||
self.scale_to_width = kwargs.get('scale_to_width', self.width)
|
||||
self.scale_to_height = kwargs.get('scale_to_height', self.height)
|
||||
# crop values are from scaled size
|
||||
self.crop_x = kwargs.get('crop_x', 0)
|
||||
self.crop_y = kwargs.get('crop_y', 0)
|
||||
self.crop_width = kwargs.get('crop_width', self.scale_to_width)
|
||||
self.crop_height = kwargs.get('crop_height', self.scale_to_height)
|
||||
|
||||
|
||||
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
file_list: List['FileItem'] = []
|
||||
|
||||
def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
|
||||
super().__init__()
|
||||
self.dataset_config = dataset_config
|
||||
self.folder_path = dataset_config.folder_path
|
||||
self.caption_type = dataset_config.caption_type
|
||||
self.default_caption = dataset_config.default_caption
|
||||
self.random_scale = dataset_config.random_scale
|
||||
self.scale = dataset_config.scale
|
||||
self.batch_size = batch_size
|
||||
# we always random crop if random scale is enabled
|
||||
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
|
||||
self.resolution = dataset_config.resolution
|
||||
|
||||
# get the file list
|
||||
self.file_list = [
|
||||
file_list = [
|
||||
os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
|
||||
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
|
||||
]
|
||||
|
||||
# this might take a while
|
||||
print(f" - Preprocessing image dimensions")
|
||||
new_file_list = []
|
||||
bad_count = 0
|
||||
for file in tqdm(self.file_list):
|
||||
img = Image.open(file)
|
||||
if int(min(img.size) * self.scale) >= self.resolution:
|
||||
new_file_list.append(file)
|
||||
for file in tqdm(file_list):
|
||||
try:
|
||||
w, h = image_utils.get_image_size(file)
|
||||
except image_utils.UnknownImageFormat:
|
||||
print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \
|
||||
f'This process is faster for png, jpeg')
|
||||
img = Image.open(file)
|
||||
h, w = img.size
|
||||
if int(min(h, w) * self.scale) >= self.resolution:
|
||||
self.file_list.append(
|
||||
FileItem(
|
||||
path=file,
|
||||
width=w,
|
||||
height=h,
|
||||
scale_to_width=int(w * self.scale),
|
||||
scale_to_height=int(h * self.scale),
|
||||
)
|
||||
)
|
||||
else:
|
||||
bad_count += 1
|
||||
|
||||
@@ -324,35 +361,57 @@ class AiToolkitDataset(Dataset, CaptionMixin):
|
||||
print(f" - Found {bad_count} images that are too small")
|
||||
assert len(self.file_list) > 0, f"no images found in {self.folder_path}"
|
||||
|
||||
if self.dataset_config.buckets:
|
||||
# setup buckets
|
||||
self.setup_buckets()
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
if self.dataset_config.buckets:
|
||||
return len(self.batch_indices)
|
||||
return len(self.file_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path = self.file_list[index]
|
||||
img = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
def _get_single_item(self, index):
|
||||
file_item = self.file_list[index]
|
||||
# todo make sure this matches
|
||||
img = exif_transpose(Image.open(file_item.path)).convert('RGB')
|
||||
w, h = img.size
|
||||
if w > h and file_item.scale_to_width < file_item.scale_to_height:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
|
||||
elif h > w and file_item.scale_to_height < file_item.scale_to_width:
|
||||
# throw error, they should match
|
||||
raise ValueError(
|
||||
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
|
||||
|
||||
# Downscale the source image first
|
||||
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
|
||||
min_img_size = min(img.size)
|
||||
|
||||
if self.random_crop:
|
||||
if self.random_scale and min_img_size > self.resolution:
|
||||
if min_img_size < self.resolution:
|
||||
print(
|
||||
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
|
||||
scale_size = self.resolution
|
||||
else:
|
||||
scale_size = random.randint(self.resolution, int(min_img_size))
|
||||
img = img.resize((scale_size, scale_size), Image.BICUBIC)
|
||||
img = transforms.RandomCrop(self.resolution)(img)
|
||||
if self.dataset_config.buckets:
|
||||
# todo allow scaling and cropping, will be hard to add
|
||||
# scale and crop based on file item
|
||||
img = img.resize((file_item.scale_to_width, file_item.scale_to_height), Image.BICUBIC)
|
||||
img = transforms.CenterCrop((file_item.crop_height, file_item.crop_width))(img)
|
||||
else:
|
||||
img = transforms.CenterCrop(min_img_size)(img)
|
||||
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
|
||||
if self.random_crop:
|
||||
if self.random_scale and min_img_size > self.resolution:
|
||||
if min_img_size < self.resolution:
|
||||
print(
|
||||
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file_item.path}")
|
||||
scale_size = self.resolution
|
||||
else:
|
||||
scale_size = random.randint(self.resolution, int(min_img_size))
|
||||
img = img.resize((scale_size, scale_size), Image.BICUBIC)
|
||||
img = transforms.RandomCrop(self.resolution)(img)
|
||||
else:
|
||||
img = transforms.CenterCrop(min_img_size)(img)
|
||||
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
|
||||
|
||||
img = self.transform(img)
|
||||
|
||||
@@ -367,6 +426,31 @@ class AiToolkitDataset(Dataset, CaptionMixin):
|
||||
else:
|
||||
return img, dataset_config_dict
|
||||
|
||||
def __getitem__(self, item):
|
||||
if self.dataset_config.buckets:
|
||||
# we collate ourselves
|
||||
idx_list = self.batch_indices[item]
|
||||
tensor_list = []
|
||||
prompt_list = []
|
||||
dataset_config_dict_list = []
|
||||
for idx in idx_list:
|
||||
if self.caption_type is not None:
|
||||
img, prompt, dataset_config_dict = self._get_single_item(idx)
|
||||
prompt_list.append(prompt)
|
||||
dataset_config_dict_list.append(dataset_config_dict)
|
||||
else:
|
||||
img, dataset_config_dict = self._get_single_item(idx)
|
||||
dataset_config_dict_list.append(dataset_config_dict)
|
||||
tensor_list.append(img.unsqueeze(0))
|
||||
|
||||
if self.caption_type is not None:
|
||||
return torch.cat(tensor_list, dim=0), prompt_list, dataset_config_dict_list
|
||||
else:
|
||||
return torch.cat(tensor_list, dim=0), dataset_config_dict_list
|
||||
else:
|
||||
# Dataloader is batching
|
||||
return self._get_single_item(item)
|
||||
|
||||
|
||||
def get_dataloader_from_datasets(dataset_options, batch_size=1):
|
||||
# TODO do bucketing
|
||||
@@ -374,22 +458,43 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1):
|
||||
return None
|
||||
|
||||
datasets = []
|
||||
has_buckets = False
|
||||
for dataset_option in dataset_options:
|
||||
if isinstance(dataset_option, DatasetConfig):
|
||||
config = dataset_option
|
||||
else:
|
||||
config = DatasetConfig(**dataset_option)
|
||||
if config.type == 'image':
|
||||
dataset = AiToolkitDataset(config)
|
||||
dataset = AiToolkitDataset(config, batch_size=batch_size)
|
||||
datasets.append(dataset)
|
||||
if config.buckets:
|
||||
has_buckets = True
|
||||
else:
|
||||
raise ValueError(f"invalid dataset type: {config.type}")
|
||||
|
||||
concatenated_dataset = ConcatDataset(datasets)
|
||||
data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=2
|
||||
)
|
||||
if has_buckets:
|
||||
# make sure they all have buckets
|
||||
for dataset in datasets:
|
||||
assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"
|
||||
|
||||
def custom_collate_fn(batch):
|
||||
# just return as is
|
||||
return batch
|
||||
|
||||
data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=None, # we batch in the dataloader
|
||||
drop_last=False,
|
||||
shuffle=True,
|
||||
collate_fn=custom_collate_fn, # Use the custom collate function
|
||||
num_workers=2
|
||||
)
|
||||
else:
|
||||
data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=2
|
||||
)
|
||||
return data_loader
|
||||
|
||||
Reference in New Issue
Block a user