From 714854ee86b55a4e6bd4344c0b33cccafc2e1c8b Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 29 Aug 2023 10:22:19 -0600 Subject: [PATCH] Huge rework to move the batch to a DTO to make it far more modular for the future UI --- extensions_built_in/sd_trainer/SDTrainer.py | 6 + .../sd_trainer/config/train.example.yaml | 2 +- jobs/process/BaseSDTrainProcess.py | 44 ++-- jobs/process/TrainLoRAHack.py | 76 ------- jobs/process/__init__.py | 1 - testing/test_bucket_dataloader.py | 2 +- toolkit/config_modules.py | 34 ++- toolkit/data_loader.py | 194 +++++++----------- toolkit/data_transfer_object/data_loader.py | 74 +++++-- toolkit/dataloader_mixins.py | 85 +++++++- 10 files changed, 286 insertions(+), 232 deletions(-) delete mode 100644 jobs/process/TrainLoRAHack.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index f9a9639e..77ee1236 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -35,6 +35,7 @@ class SDTrainer(BaseSDTrainProcess): def hook_train_loop(self, batch): dtype = get_torch_dtype(self.train_config.dtype) noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + network_weight_list = batch.get_network_weight_list() self.optimizer.zero_grad() flush() @@ -53,6 +54,9 @@ class SDTrainer(BaseSDTrainProcess): else: network = BlankNetwork() + # set the weights + network.multiplier = network_weight_list + # activate network if it exits with network: with torch.set_grad_enabled(grad_on_text_encoder): @@ -114,5 +118,7 @@ class SDTrainer(BaseSDTrainProcess): loss_dict = OrderedDict( {'loss': loss.item()} ) + # reset network multiplier + network.multiplier = 1.0 return loss_dict diff --git a/extensions_built_in/sd_trainer/config/train.example.yaml b/extensions_built_in/sd_trainer/config/train.example.yaml index a563ee8c..793d5d55 100644 --- a/extensions_built_in/sd_trainer/config/train.example.yaml +++ b/extensions_built_in/sd_trainer/config/train.example.yaml @@ -19,7 +19,7 @@ config: max_step_saves_to_keep: 5 # only affects step counts datasets: - folder_path: "/path/to/dataset" - caption_type: "txt" + caption_ext: "txt" default_caption: "[trigger]" buckets: true resolution: 512 diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index df68c698..1d48a010 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -6,6 +6,7 @@ from typing import Union from torch.utils.data import DataLoader from toolkit.data_loader import get_dataloader_from_datasets +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.embedding import Embedding from toolkit.lora_special import LoRASpecialNetwork from toolkit.optimizer import get_optimizer @@ -23,7 +24,7 @@ import torch from tqdm import tqdm from toolkit.config_modules import SaveConfig, LogingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \ - GenerateImageConfig, EmbeddingConfig, DatasetConfig + GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config def flush(): @@ -67,6 +68,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.trigger_word = self.get_conf('trigger_word', None) raw_datasets = self.get_conf('datasets', None) + if raw_datasets is not None and len(raw_datasets) > 0: + raw_datasets = preprocess_dataset_raw_config(raw_datasets) self.datasets = None self.datasets_reg = None if raw_datasets is not None and len(raw_datasets) > 0: 
@@ -94,6 +97,12 @@ class BaseSDTrainProcess(BaseTrainProcess): if latest_save_path is not None: print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####") self.model_config.name_or_path = latest_save_path + meta = load_metadata_from_safetensors(latest_save_path) + # if 'training_info' in Orderdict keys + if 'training_info' in meta and 'step' in meta['training_info']: + self.step_num = meta['training_info']['step'] + self.start_step = self.step_num + print(f"Found step {self.step_num} in metadata, starting from there") self.sd = StableDiffusion( device=self.device, @@ -307,16 +316,9 @@ class BaseSDTrainProcess(BaseTrainProcess): def process_general_training_batch(self, batch): with torch.no_grad(): - imgs, prompts, dataset_config = batch - - # convert the 0 or 1 for is reg to a bool list - if isinstance(dataset_config, list): - is_reg_list = [x.get('is_reg', 0) for x in dataset_config] - else: - is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])]) - if isinstance(is_reg_list, torch.Tensor): - is_reg_list = is_reg_list.numpy().tolist() - is_reg_list = [bool(x) for x in is_reg_list] + imgs = batch.tensor + prompts = batch.get_caption_list() + is_reg_list = batch.get_is_reg_list() conditioned_prompts = [] @@ -473,6 +475,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # resume state from embedding self.step_num = self.embedding.step + self.start_step = self.step_num # set trainable params params = self.embedding.get_trainable_params() @@ -556,13 +559,18 @@ class BaseSDTrainProcess(BaseTrainProcess): with torch.no_grad(): # if is even step and we have a reg dataset, use that # todo improve this logic to send one of each through if we can buckets and batch size might be an issue - if step % 2 == 0 and dataloader_reg is not None: + is_reg_step = False + is_save_step = self.save_config.save_every and self.step_num % self.save_config.save_every == 0 + is_sample_step = self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0 + # don't do a reg step on sample or save steps as we dont want to normalize on those + if step % 2 == 0 and dataloader_reg is not None and not is_save_step and not is_sample_step: try: batch = next(dataloader_iterator_reg) except StopIteration: # hit the end of an epoch, reset dataloader_iterator_reg = iter(dataloader_reg) batch = next(dataloader_iterator_reg) + is_reg_step = True elif dataloader is not None: try: batch = next(dataloader_iterator) @@ -601,11 +609,11 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.step_num != self.start_step: # pause progress bar self.progress_bar.unpause() # makes it so doesn't track time - if self.sample_config.sample_every and self.step_num % self.sample_config.sample_every == 0: + if is_sample_step: # print above the progress bar self.sample(self.step_num) - if self.save_config.save_every and self.step_num % self.save_config.save_every == 0: + if is_save_step: # print above the progress bar self.print(f"Saving at step {self.step_num}") self.save(self.step_num) @@ -623,10 +631,14 @@ class BaseSDTrainProcess(BaseTrainProcess): # end of step self.step_num = step - # apply network normalizer if we are using it - if self.network is not None and self.network.is_normalizing: + # apply network normalizer if we are using it, not on regularization steps + if self.network is not None and self.network.is_normalizing and not is_reg_step: self.network.apply_stored_normalizer() + # if the batch is a DataLoaderBatchDTO, then we need to clean it up + if isinstance(batch, DataLoaderBatchDTO): 
+ batch.cleanup() + self.sample(self.step_num + 1) print("") self.save() diff --git a/jobs/process/TrainLoRAHack.py b/jobs/process/TrainLoRAHack.py deleted file mode 100644 index 2a5a6539..00000000 --- a/jobs/process/TrainLoRAHack.py +++ /dev/null @@ -1,76 +0,0 @@ -# ref: -# - https://github.com/p1atdev/LECO/blob/main/train_lora.py -import time -from collections import OrderedDict -import os - -from toolkit.config_modules import SliderConfig -from toolkit.paths import REPOS_ROOT -import sys - -sys.path.append(REPOS_ROOT) -sys.path.append(os.path.join(REPOS_ROOT, 'leco')) -from toolkit.train_tools import get_torch_dtype, apply_noise_offset -import gc - -import torch -from leco import train_util, model_util -from leco.prompt_util import PromptEmbedsCache -from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion - - -def flush(): - torch.cuda.empty_cache() - gc.collect() - - -class LoRAHack: - def __init__(self, **kwargs): - self.type = kwargs.get('type', 'suppression') - - -class TrainLoRAHack(BaseSDTrainProcess): - def __init__(self, process_id: int, job, config: OrderedDict): - super().__init__(process_id, job, config) - self.hack_config = LoRAHack(**self.get_conf('hack', {})) - - def hook_before_train_loop(self): - # we don't need text encoder so move it to cpu - self.sd.text_encoder.to("cpu") - flush() - # end hook_before_train_loop - - if self.hack_config.type == 'suppression': - # set all params to self.current_suppression - params = self.network.parameters() - for param in params: - # get random noise for each param - noise = torch.randn_like(param) - 0.5 - # apply noise to param - param.data = noise * 0.001 - - - def supress_loop(self): - dtype = get_torch_dtype(self.train_config.dtype) - - - loss_dict = OrderedDict( - {'sup': 0.0} - ) - # increase noise - for param in self.network.parameters(): - # get random noise for each param - noise = torch.randn_like(param) - 0.5 - # apply noise to param - param.data = param.data + noise * 0.001 - - - - return loss_dict - - def hook_train_loop(self, batch): - if self.hack_config.type == 'suppression': - return self.supress_loop() - else: - raise NotImplementedError(f'unknown hack type: {self.hack_config.type}') - # end hook_train_loop diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py index 766c6b11..387be088 100644 --- a/jobs/process/__init__.py +++ b/jobs/process/__init__.py @@ -7,7 +7,6 @@ from .TrainVAEProcess import TrainVAEProcess from .BaseMergeProcess import BaseMergeProcess from .TrainSliderProcess import TrainSliderProcess from .TrainSliderProcessOld import TrainSliderProcessOld -from .TrainLoRAHack import TrainLoRAHack from .TrainSDRescaleProcess import TrainSDRescaleProcess from .ModRescaleLoraProcess import ModRescaleLoraProcess from .GenerateProcess import GenerateProcess diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py index 85a21ef7..6c1eec7b 100644 --- a/testing/test_bucket_dataloader.py +++ b/testing/test_bucket_dataloader.py @@ -22,7 +22,7 @@ batch_size = 4 dataset_config = DatasetConfig( folder_path=dataset_folder, resolution=resolution, - caption_type='txt', + caption_ext='txt', default_caption='default', buckets=True, bucket_tolerance=bucket_tolerance, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 7f7744d8..a06b6a94 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -5,6 +5,7 @@ import random ImgExt = Literal['jpg', 'png', 'webp'] + class SaveConfig: def __init__(self, **kwargs): self.save_every: int = 
kwargs.get('save_every', 1000) @@ -167,9 +168,13 @@ class DatasetConfig: def __init__(self, **kwargs): self.type = kwargs.get('type', 'image') # sd, slider, reference + # will be legacy self.folder_path: str = kwargs.get('folder_path', None) + # can be json or folder path + self.dataset_path: str = kwargs.get('dataset_path', None) + self.default_caption: str = kwargs.get('default_caption', None) - self.caption_type: str = kwargs.get('caption_type', None) + self.caption_ext: str = kwargs.get('caption_ext', None) self.random_scale: bool = kwargs.get('random_scale', False) self.random_crop: bool = kwargs.get('random_crop', False) self.resolution: int = kwargs.get('resolution', 512) @@ -182,6 +187,33 @@ class DatasetConfig: self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False) self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0)) + # legacy compatability + legacy_caption_type = kwargs.get('caption_type', None) + if legacy_caption_type: + self.caption_ext = legacy_caption_type + self.caption_type = self.caption_ext + + +def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: + """ + This just splits up the datasets by resolutions so you dont have to do it manually + :param raw_config: + :return: + """ + # split up datasets by resolutions + new_config = [] + for dataset in raw_config: + resolution = dataset.get('resolution', 512) + if isinstance(resolution, list): + resolution_list = resolution + else: + resolution_list = [resolution] + for res in resolution_list: + dataset_copy = dataset.copy() + dataset_copy['resolution'] = res + new_config.append(dataset_copy) + return new_config + class GenerateImageConfig: def __init__( diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 5eb1c371..c29b3510 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -1,3 +1,4 @@ +import json import os import random from typing import List @@ -13,10 +14,9 @@ from tqdm import tqdm import albumentations as A from toolkit import image_utils -from toolkit.config_modules import DatasetConfig +from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin -from toolkit.data_transfer_object.data_loader import FileItemDTO - +from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO class ImageDataset(Dataset, CaptionMixin): @@ -29,7 +29,7 @@ class ImageDataset(Dataset, CaptionMixin): self.include_prompt = self.get_config('include_prompt', False) self.default_prompt = self.get_config('default_prompt', '') if self.include_prompt: - self.caption_type = self.get_config('caption_type', 'txt') + self.caption_type = self.get_config('caption_ext', 'txt') else: self.caption_type = None # we always random crop if random scale is enabled @@ -288,24 +288,17 @@ class PairedImageDataset(Dataset): return img, prompt, (self.neg_weight, self.pos_weight) -printed_messages = [] - - -def print_once(msg): - global printed_messages - if msg not in printed_messages: - print(msg) - printed_messages.append(msg) - - - class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin): def __init__(self, dataset_config: 'DatasetConfig', batch_size=1): super().__init__() self.dataset_config = dataset_config - self.folder_path = dataset_config.folder_path - self.caption_type = dataset_config.caption_type + folder_path = dataset_config.folder_path + self.dataset_path = dataset_config.dataset_path + if self.dataset_path is None: + self.dataset_path = folder_path + + 
self.caption_type = dataset_config.caption_ext self.default_caption = dataset_config.default_caption self.random_scale = dataset_config.random_scale self.scale = dataset_config.scale @@ -313,147 +306,96 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin): # we always random crop if random scale is enabled self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop self.resolution = dataset_config.resolution + self.caption_dict = None self.file_list: List['FileItemDTO'] = [] - # get the file list - file_list = [ - os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if - file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')) - ] + # check if dataset_path is a folder or json + if os.path.isdir(self.dataset_path): + file_list = [ + os.path.join(self.dataset_path, file) for file in os.listdir(self.dataset_path) if + file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')) + ] + else: + # assume json + with open(self.dataset_path, 'r') as f: + self.caption_dict = json.load(f) + # keys are file paths + file_list = list(self.caption_dict.keys()) # this might take a while print(f" - Preprocessing image dimensions") bad_count = 0 for file in tqdm(file_list): - try: - w, h = image_utils.get_image_size(file) - except image_utils.UnknownImageFormat: - print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \ - f'This process is faster for png, jpeg') - img = Image.open(file) - h, w = img.size - # TODO allow smaller images - if int(min(h, w) * self.scale) >= self.resolution: - self.file_list.append( - FileItemDTO( - path=file, - width=w, - height=h, - scale_to_width=int(w * self.scale), - scale_to_height=int(h * self.scale), - dataset_config=dataset_config - ) - ) - else: + file_item = FileItemDTO( + path=file, + dataset_config=dataset_config + ) + if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution: bad_count += 1 + else: + self.file_list.append(file_item) print(f" - Found {len(self.file_list)} images") print(f" - Found {bad_count} images that are too small") - assert len(self.file_list) > 0, f"no images found in {self.folder_path}" + assert len(self.file_list) > 0, f"no images found in {self.dataset_path}" - if self.dataset_config.buckets: - # setup buckets - self.setup_buckets() + self.setup_epoch() self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1] ]) + def setup_epoch(self): + # TODO: set this up to redo cropping and everything else + # do not call for now + if self.dataset_config.buckets: + # setup buckets + self.setup_buckets() + def __len__(self): if self.dataset_config.buckets: return len(self.batch_indices) return len(self.file_list) - def _get_single_item(self, index): + def _get_single_item(self, index) -> 'FileItemDTO': file_item = self.file_list[index] - # todo make sure this matches - img = exif_transpose(Image.open(file_item.path)).convert('RGB') - w, h = img.size - if w > h and file_item.scale_to_width < file_item.scale_to_height: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}") - elif h > w and file_item.scale_to_height < file_item.scale_to_width: - # throw error, they should match - raise ValueError( - f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, 
file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}") - - # Downscale the source image first - img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC) - min_img_size = min(img.size) - - if self.dataset_config.buckets: - # todo allow scaling and cropping, will be hard to add - # scale and crop based on file item - img = img.resize((file_item.scale_to_width, file_item.scale_to_height), Image.BICUBIC) - img = transforms.CenterCrop((file_item.crop_height, file_item.crop_width))(img) - else: - if self.random_crop: - if self.random_scale and min_img_size > self.resolution: - if min_img_size < self.resolution: - print( - f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file_item.path}") - scale_size = self.resolution - else: - scale_size = random.randint(self.resolution, int(min_img_size)) - img = img.resize((scale_size, scale_size), Image.BICUBIC) - img = transforms.RandomCrop(self.resolution)(img) - else: - img = transforms.CenterCrop(min_img_size)(img) - img = img.resize((self.resolution, self.resolution), Image.BICUBIC) - - img = self.transform(img) - - # todo convert it all - dataset_config_dict = { - "is_reg": 1 if self.dataset_config.is_reg else 0, - } - - if self.caption_type is not None: - prompt = self.get_caption_item(index) - return img, prompt, dataset_config_dict - else: - return img, dataset_config_dict + file_item.load_and_process_image(self.transform) + file_item.load_caption(self.caption_dict) + return file_item def __getitem__(self, item): if self.dataset_config.buckets: + # for buckets we collate ourselves for now + # todo allow a scheduler to dynamically make buckets # we collate ourselves idx_list = self.batch_indices[item] - tensor_list = [] - prompt_list = [] - dataset_config_dict_list = [] - for idx in idx_list: - if self.caption_type is not None: - img, prompt, dataset_config_dict = self._get_single_item(idx) - prompt_list.append(prompt) - dataset_config_dict_list.append(dataset_config_dict) - else: - img, dataset_config_dict = self._get_single_item(idx) - dataset_config_dict_list.append(dataset_config_dict) - tensor_list.append(img.unsqueeze(0)) - - if self.caption_type is not None: - return torch.cat(tensor_list, dim=0), prompt_list, dataset_config_dict_list - else: - return torch.cat(tensor_list, dim=0), dataset_config_dict_list + return [self._get_single_item(idx) for idx in idx_list] else: # Dataloader is batching return self._get_single_item(item) def get_dataloader_from_datasets(dataset_options, batch_size=1): - # TODO do bucketing if dataset_options is None or len(dataset_options) == 0: return None datasets = [] has_buckets = False + + dataset_config_list = [] + # preprocess them all for dataset_option in dataset_options: if isinstance(dataset_option, DatasetConfig): - config = dataset_option + dataset_config_list.append(dataset_option) else: - config = DatasetConfig(**dataset_option) + # preprocess raw data + split_configs = preprocess_dataset_raw_config([dataset_option]) + for x in split_configs: + dataset_config_list.append(DatasetConfig(**x)) + + for config in dataset_config_list: + if config.type == 'image': dataset = AiToolkitDataset(config, batch_size=batch_size) datasets.append(dataset) @@ -463,21 +405,28 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1): raise ValueError(f"invalid dataset type: {config.type}") concatenated_dataset = ConcatDataset(datasets) + + # todo build scheduler that can get buckets from all 
datasets that match + # todo and evenly distribute reg images + + def dto_collation(batch: List['FileItemDTO']): + # create DTO batch + batch = DataLoaderBatchDTO( + file_items=batch + ) + return batch + if has_buckets: # make sure they all have buckets for dataset in datasets: assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none" - def custom_collate_fn(batch): - # just return as is - return batch - data_loader = DataLoader( concatenated_dataset, - batch_size=None, # we batch in the dataloader + batch_size=None, # we batch in the datasets for now drop_last=False, shuffle=True, - collate_fn=custom_collate_fn, # Use the custom collate function + collate_fn=dto_collation, # Use the custom collate function num_workers=2 ) else: @@ -485,6 +434,7 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1): concatenated_dataset, batch_size=batch_size, shuffle=True, - num_workers=2 + num_workers=2, + collate_fn=dto_collation ) return data_loader diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index f2c9a509..32a91f48 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -1,36 +1,84 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Union import torch import random -from toolkit.dataloader_mixins import CaptionProcessingDTOMixin +from PIL import Image +from PIL.ImageOps import exif_transpose + +from toolkit import image_utils +from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin if TYPE_CHECKING: from toolkit.config_modules import DatasetConfig +printed_messages = [] -class FileItemDTO(CaptionProcessingDTOMixin): + +def print_once(msg): + global printed_messages + if msg not in printed_messages: + print(msg) + printed_messages.append(msg) + + +class FileItemDTO(CaptionProcessingDTOMixin, ImageProcessingDTOMixin): def __init__(self, **kwargs): self.path = kwargs.get('path', None) - self.caption_path: str = kwargs.get('caption_path', None) + self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + # process width and height + try: + w, h = image_utils.get_image_size(self.path) + except image_utils.UnknownImageFormat: + print_once(f'Warning: Some images in the dataset cannot be fast read. 
' + \ + f'This process is faster for png, jpeg') + img = exif_transpose(Image.open(self.path)) + h, w = img.size + self.width: int = w + self.height: int = h + + # self.caption_path: str = kwargs.get('caption_path', None) self.raw_caption: str = kwargs.get('raw_caption', None) - self.width: int = kwargs.get('width', None) - self.height: int = kwargs.get('height', None) # we scale first, then crop - self.scale_to_width: int = kwargs.get('scale_to_width', self.width) - self.scale_to_height: int = kwargs.get('scale_to_height', self.height) + self.scale_to_width: int = kwargs.get('scale_to_width', int(self.width * self.dataset_config.scale)) + self.scale_to_height: int = kwargs.get('scale_to_height', int(self.height * self.dataset_config.scale)) # crop values are from scaled size self.crop_x: int = kwargs.get('crop_x', 0) self.crop_y: int = kwargs.get('crop_y', 0) self.crop_width: int = kwargs.get('crop_width', self.scale_to_width) self.crop_height: int = kwargs.get('crop_height', self.scale_to_height) - # process config - self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + self.network_weight: float = self.dataset_config.network_weight + self.is_reg = self.dataset_config.is_reg + self.tensor: Union[torch.Tensor, None] = None - self.network_network_weight: float = self.dataset_config.network_weight + def cleanup(self): + self.tensor = None class DataLoaderBatchDTO: def __init__(self, **kwargs): - self.file_item: 'FileItemDTO' = kwargs.get('file_item', None) - self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None) + self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items]) + + def get_is_reg_list(self): + return [x.is_reg for x in self.file_items] + + def get_network_weight_list(self): + return [x.network_weight for x in self.file_items] + + def get_caption_list( + self, + trigger=None, + to_replace_list=None, + add_if_not_present=True + ): + return [x.get_caption( + trigger=trigger, + to_replace_list=to_replace_list, + add_if_not_present=add_if_not_present + ) for x in self.file_items] + + def cleanup(self): + self.tensor = None + for file_item in self.file_items: + file_item.cleanup() diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index bc4be6a7..74567946 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -1,8 +1,11 @@ import os import random -from typing import TYPE_CHECKING, List, Dict +from typing import TYPE_CHECKING, List, Dict, Union from toolkit.prompt_utils import inject_trigger_into_prompt +from torchvision import transforms +from PIL import Image +from PIL.ImageOps import exif_transpose if TYPE_CHECKING: from toolkit.data_loader import AiToolkitDataset @@ -159,6 +162,38 @@ class BucketsMixin: class CaptionProcessingDTOMixin: + + # todo allow for loading from sd-scripts style dict + def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]): + if self.raw_caption is not None: + # we already loaded it + pass + elif caption_dict is not None and self.path in caption_dict and "caption" in caption_dict[self.path]: + self.raw_caption = caption_dict[self.path]["caption"] + else: + # see if prompt file exists + path_no_ext = os.path.splitext(self.path)[0] + prompt_ext = self.dataset_config.caption_ext + prompt_path = f"{path_no_ext}.{prompt_ext}" + + if os.path.exists(prompt_path): + with open(prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() + # remove any newlines + 
prompt = prompt.replace('\n', ', ') + # remove new lines for all operating systems + prompt = prompt.replace('\r', ', ') + prompt_split = prompt.split(',') + # remove empty strings + prompt_split = [p.strip() for p in prompt_split if p.strip()] + # join back together + prompt = ', '.join(prompt_split) + else: + prompt = '' + if self.dataset_config.default_caption is not None: + prompt = self.dataset_config.default_caption + self.raw_caption = prompt + def get_caption( self: 'FileItemDTO', trigger=None, @@ -201,3 +236,51 @@ class CaptionProcessingDTOMixin: caption = ', '.join(token_list) caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present) return caption + + +class ImageProcessingDTOMixin: + def load_and_process_image( + self: 'FileItemDTO', + transform: Union[None, transforms.Compose] + ): + # todo make sure this matches + img = exif_transpose(Image.open(self.path)).convert('RGB') + w, h = img.size + if w > h and self.scale_to_width < self.scale_to_height: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + elif h > w and self.scale_to_height < self.scale_to_width: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}") + + if self.dataset_config.buckets: + # todo allow scaling and cropping, will be hard to add + # scale and crop based on file item + img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) + img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) + else: + # Downscale the source image first + img = img.resize( + (int(img.size[0] * self.dataset_config.scale), int(img.size[1] * self.dataset_config.scale)), + Image.BICUBIC) + min_img_size = min(img.size) + if self.dataset_config.random_crop: + if self.dataset_config.random_scale and min_img_size > self.dataset_config.resolution: + if min_img_size < self.dataset_config.resolution: + print( + f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.dataset_config.resolution}, image file={self.path}") + scale_size = self.dataset_config.resolution + else: + scale_size = random.randint(self.dataset_config.resolution, int(min_img_size)) + img = img.resize((scale_size, scale_size), Image.BICUBIC) + img = transforms.RandomCrop(self.dataset_config.resolution)(img) + else: + img = transforms.CenterCrop(min_img_size)(img) + img = img.resize((self.dataset_config.resolution, self.dataset_config.resolution), Image.BICUBIC) + + if transform: + img = transform(img) + + self.tensor = img
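
Note for reviewers (not part of the patch): the core of this rework is that a batch is no longer an (imgs, prompts, dataset_config) tuple but a DataLoaderBatchDTO built from per-item FileItemDTO objects by the dto_collation collate function. Below is a minimal sketch of how a training hook reads the new batch, using only the accessors added in this patch; the consume_batch wrapper itself is illustrative and not part of the codebase.

from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO


def consume_batch(batch: DataLoaderBatchDTO) -> dict:
    # stacked image tensor assembled from each FileItemDTO.tensor, shape (B, C, H, W)
    imgs = batch.tensor
    # per-item captions; trigger-word injection is available through the method's kwargs
    prompts = batch.get_caption_list()
    # per-item flags and weights that previously had to be dug out of dataset_config dicts
    is_reg_list = batch.get_is_reg_list()
    network_weight_list = batch.get_network_weight_list()

    # ... forward/backward pass would go here, e.g. setting network.multiplier to
    # network_weight_list and resetting it to 1.0 afterwards, as SDTrainer now does ...

    # release the image tensors once the step is done, mirroring the cleanup()
    # call BaseSDTrainProcess now makes at the end of every step
    batch.cleanup()
    return {"batch_size": imgs.shape[0], "captions": prompts, "is_reg": is_reg_list}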
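
A second behavioral change worth calling out: preprocess_dataset_raw_config fans a dataset entry whose resolution is a list out into one config per resolution, and caption_ext replaces the legacy caption_type key (which is still accepted and mapped onto caption_ext). A small sketch of the fan-out follows; the dataset path is a placeholder.

from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config

raw_datasets = [{
    "folder_path": "/path/to/dataset",  # placeholder path
    "caption_ext": "txt",               # new key; legacy "caption_type" still maps onto it
    "default_caption": "[trigger]",
    "buckets": True,
    "resolution": [512, 768],           # a list is split into one dataset config per size
}]

configs = [DatasetConfig(**d) for d in preprocess_dataset_raw_config(raw_datasets)]
print([c.resolution for c in configs])  # -> [512, 768]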