Huge rework to move the batch to a DTO, making it far more modular for the future UI

Jaret Burkett
2023-08-29 10:22:19 -06:00
parent bd758ff203
commit 714854ee86
10 changed files with 286 additions and 232 deletions
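For context, here is a minimal sketch of the DTO pattern this commit moves to. The real classes live in toolkit/data_transfer_object/data_loader.py; any fields beyond file_items, path, width/height, and the scale_to_* values visible in the diff (e.g. tensor, caption) are assumptions, not the repo's actual API.

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class FileItemDTO:
    # per-image state that previously lived inside the dataset class
    path: str
    width: int = 0
    height: int = 0
    scale_to_width: int = 0
    scale_to_height: int = 0
    caption: Optional[str] = None          # assumed field, filled by load_caption
    tensor: Optional[torch.Tensor] = None  # assumed field, filled by load_and_process_image


class DataLoaderBatchDTO:
    # wraps a list of file items so the trainer receives one object
    # instead of positional tuples of (tensor, prompt, config_dict)
    def __init__(self, file_items: List[FileItemDTO]):
        self.file_items = file_items
        self.tensor = torch.cat([f.tensor.unsqueeze(0) for f in file_items], dim=0)
        self.captions = [f.caption for f in file_items]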


@@ -1,3 +1,4 @@
import json
import os
import random
from typing import List
@@ -13,10 +14,9 @@ from tqdm import tqdm
import albumentations as A
from toolkit import image_utils
from toolkit.config_modules import DatasetConfig
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
from toolkit.data_transfer_object.data_loader import FileItemDTO
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
class ImageDataset(Dataset, CaptionMixin):
@@ -29,7 +29,7 @@ class ImageDataset(Dataset, CaptionMixin):
self.include_prompt = self.get_config('include_prompt', False)
self.default_prompt = self.get_config('default_prompt', '')
if self.include_prompt:
self.caption_type = self.get_config('caption_type', 'txt')
self.caption_type = self.get_config('caption_ext', 'txt')
else:
self.caption_type = None
# we always random crop if random scale is enabled
@@ -288,24 +288,17 @@ class PairedImageDataset(Dataset):
return img, prompt, (self.neg_weight, self.pos_weight)
printed_messages = []
def print_once(msg):
global printed_messages
if msg not in printed_messages:
print(msg)
printed_messages.append(msg)
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
super().__init__()
self.dataset_config = dataset_config
self.folder_path = dataset_config.folder_path
self.caption_type = dataset_config.caption_type
folder_path = dataset_config.folder_path
self.dataset_path = dataset_config.dataset_path
if self.dataset_path is None:
self.dataset_path = folder_path
self.caption_type = dataset_config.caption_ext
self.default_caption = dataset_config.default_caption
self.random_scale = dataset_config.random_scale
self.scale = dataset_config.scale
@@ -313,147 +306,96 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
self.resolution = dataset_config.resolution
self.caption_dict = None
self.file_list: List['FileItemDTO'] = []
# get the file list
file_list = [
os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
# check if dataset_path is a folder or json
if os.path.isdir(self.dataset_path):
file_list = [
os.path.join(self.dataset_path, file) for file in os.listdir(self.dataset_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
else:
# assume json
with open(self.dataset_path, 'r') as f:
self.caption_dict = json.load(f)
# keys are file paths
file_list = list(self.caption_dict.keys())
# this might take a while
print(f" - Preprocessing image dimensions")
bad_count = 0
for file in tqdm(file_list):
try:
w, h = image_utils.get_image_size(file)
except image_utils.UnknownImageFormat:
print_once('Warning: some images in the dataset cannot be fast-read. '
'This process is faster for png and jpeg files.')
img = Image.open(file)
h, w = img.size
# TODO allow smaller images
if int(min(h, w) * self.scale) >= self.resolution:
self.file_list.append(
FileItemDTO(
path=file,
width=w,
height=h,
scale_to_width=int(w * self.scale),
scale_to_height=int(h * self.scale),
dataset_config=dataset_config
)
)
else:
file_item = FileItemDTO(
path=file,
dataset_config=dataset_config
)
if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution:
bad_count += 1
else:
self.file_list.append(file_item)
print(f" - Found {len(self.file_list)} images")
print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.folder_path}"
assert len(self.file_list) > 0, f"no images found in {self.dataset_path}"
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
self.setup_epoch()
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
def setup_epoch(self):
# TODO: set this up to redo cropping and everything else
# do not call for now
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
def __len__(self):
if self.dataset_config.buckets:
return len(self.batch_indices)
return len(self.file_list)
def _get_single_item(self, index):
def _get_single_item(self, index) -> 'FileItemDTO':
file_item = self.file_list[index]
# todo make sure this matches
img = exif_transpose(Image.open(file_item.path)).convert('RGB')
w, h = img.size
if w > h and file_item.scale_to_width < file_item.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
elif h > w and file_item.scale_to_height < file_item.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
# Downscale the source image first
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
min_img_size = min(img.size)
if self.dataset_config.buckets:
# todo allow scaling and cropping, will be hard to add
# scale and crop based on file item
img = img.resize((file_item.scale_to_width, file_item.scale_to_height), Image.BICUBIC)
img = transforms.CenterCrop((file_item.crop_height, file_item.crop_width))(img)
else:
if self.random_crop:
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file_item.path}")
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
img = self.transform(img)
# todo convert it all
dataset_config_dict = {
"is_reg": 1 if self.dataset_config.is_reg else 0,
}
if self.caption_type is not None:
prompt = self.get_caption_item(index)
return img, prompt, dataset_config_dict
else:
return img, dataset_config_dict
file_item.load_and_process_image(self.transform)
file_item.load_caption(self.caption_dict)
return file_item
def __getitem__(self, item):
if self.dataset_config.buckets:
# for buckets we collate ourselves for now
# todo allow a scheduler to dynamically make buckets
# we collate ourselves
idx_list = self.batch_indices[item]
tensor_list = []
prompt_list = []
dataset_config_dict_list = []
for idx in idx_list:
if self.caption_type is not None:
img, prompt, dataset_config_dict = self._get_single_item(idx)
prompt_list.append(prompt)
dataset_config_dict_list.append(dataset_config_dict)
else:
img, dataset_config_dict = self._get_single_item(idx)
dataset_config_dict_list.append(dataset_config_dict)
tensor_list.append(img.unsqueeze(0))
if self.caption_type is not None:
return torch.cat(tensor_list, dim=0), prompt_list, dataset_config_dict_list
else:
return torch.cat(tensor_list, dim=0), dataset_config_dict_list
return [self._get_single_item(idx) for idx in idx_list]
else:
# Dataloader is batching
return self._get_single_item(item)
def get_dataloader_from_datasets(dataset_options, batch_size=1):
# TODO do bucketing
if dataset_options is None or len(dataset_options) == 0:
return None
datasets = []
has_buckets = False
dataset_config_list = []
# preprocess them all
for dataset_option in dataset_options:
if isinstance(dataset_option, DatasetConfig):
config = dataset_option
dataset_config_list.append(dataset_option)
else:
config = DatasetConfig(**dataset_option)
# preprocess raw data
split_configs = preprocess_dataset_raw_config([dataset_option])
for x in split_configs:
dataset_config_list.append(DatasetConfig(**x))
for config in dataset_config_list:
if config.type == 'image':
dataset = AiToolkitDataset(config, batch_size=batch_size)
datasets.append(dataset)
@@ -463,21 +405,28 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1):
raise ValueError(f"invalid dataset type: {config.type}")
concatenated_dataset = ConcatDataset(datasets)
# todo build scheduler that can get buckets from all datasets that match
# todo and evenly distribute reg images
def dto_collation(batch: List['FileItemDTO']):
# create DTO batch
batch = DataLoaderBatchDTO(
file_items=batch
)
return batch
if has_buckets:
# make sure they all have buckets
for dataset in datasets:
assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"
def custom_collate_fn(batch):
# just return as is
return batch
data_loader = DataLoader(
concatenated_dataset,
batch_size=None, # we batch in the dataloader
batch_size=None, # we batch in the datasets for now
drop_last=False,
shuffle=True,
collate_fn=custom_collate_fn, # Use the custom collate function
collate_fn=dto_collation, # Use the custom collate function
num_workers=2
)
else:
@@ -485,6 +434,7 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1):
concatenated_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2
num_workers=2,
collate_fn=dto_collation
)
return data_loader
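With this collate function in place, a training loop receives one DataLoaderBatchDTO per step rather than unpacking tuples. A hypothetical consumer, assuming the tensor/captions attributes from the sketch above (the real DTO's attribute names are not shown in this diff):

dataloader = get_dataloader_from_datasets(dataset_options, batch_size=4)
for batch in dataloader:
    images = batch.tensor      # (B, C, H, W), normalized to [-1, 1] by the dataset transform
    prompts = batch.captions   # list of strings (or None when no caption source is configured)
    # ... run the training step on (images, prompts)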