diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 77ee1236..25ad1b6c 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -60,12 +60,7 @@ class SDTrainer(BaseSDTrainProcess): # activate network if it exits with network: with torch.set_grad_enabled(grad_on_text_encoder): - embedding_list = [] - # embed the prompts - for prompt in conditioned_prompts: - embedding = self.sd.encode_prompt(prompt).to(self.device_torch, dtype=dtype) - embedding_list.append(embedding) - conditional_embeds = concat_prompt_embeds(embedding_list) + conditional_embeds = self.sd.encode_prompt(conditioned_prompts).to(self.device_torch, dtype=dtype) if not grad_on_text_encoder: # detach the embeddings conditional_embeds = conditional_embeds.detach() diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py index 6c1eec7b..3d8419b6 100644 --- a/testing/test_bucket_dataloader.py +++ b/testing/test_bucket_dataloader.py @@ -1,10 +1,21 @@ -from torch.utils.data import ConcatDataset, DataLoader -from tqdm import tqdm -# make sure we can import from the toolkit +import time + +import numpy as np +import torch +from torchvision import transforms import sys import os +import cv2 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from toolkit.paths import SD_SCRIPTS_ROOT + +from toolkit.image_utils import show_img + +sys.path.append(SD_SCRIPTS_ROOT) + +from library.model_util import load_vae +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets from toolkit.config_modules import DatasetConfig import argparse @@ -12,6 +23,7 @@ import argparse parser = argparse.ArgumentParser() parser.add_argument('dataset_folder', type=str, default='input') + args = parser.parse_args() dataset_folder = args.dataset_folder @@ -30,8 +42,29 @@ dataset_config = 
DatasetConfig( dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size) + # run through an epoch and check sizes for batch in dataloader: - print(list(batch[0].shape)) + batch: 'DataLoaderBatchDTO' + img_batch = batch.tensor + + chunks = torch.chunk(img_batch, batch_size, dim=0) + # put them so they are side by side + big_img = torch.cat(chunks, dim=3) + big_img = big_img.squeeze(0) + + min_val = big_img.min() + max_val = big_img.max() + + big_img = (big_img / 2 + 0.5).clamp(0, 1) + + # convert to image + img = transforms.ToPILImage()(big_img) + + show_img(img) + + time.sleep(1.0) + +cv2.destroyAllWindows() print('done') diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index c29b3510..90c8171b 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -326,14 +326,19 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin): print(f" - Preprocessing image dimensions") bad_count = 0 for file in tqdm(file_list): - file_item = FileItemDTO( - path=file, - dataset_config=dataset_config - ) - if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution: + try: + file_item = FileItemDTO( + path=file, + dataset_config=dataset_config + ) + if file_item.scale_to_width < self.resolution or file_item.scale_to_height < self.resolution: + bad_count += 1 + else: + self.file_list.append(file_item) + except Exception as e: + print(f"Error processing image: {file}") + print(e) bad_count += 1 - else: - self.file_list.append(file_item) print(f" - Found {len(self.file_list)} images") print(f" - Found {bad_count} images that are too small") @@ -376,7 +381,7 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin): return self._get_single_item(item) -def get_dataloader_from_datasets(dataset_options, batch_size=1): +def get_dataloader_from_datasets(dataset_options, batch_size=1) -> DataLoader: if dataset_options is None or len(dataset_options) == 0: return None diff --git
a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index c25cfb14..a09ed5d8 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -1,3 +1,4 @@ +import math import os import random from typing import TYPE_CHECKING, List, Dict, Union @@ -94,56 +95,56 @@ class BucketsMixin: bucket_tolerance = config.bucket_tolerance file_list: List['FileItemDTO'] = self.file_list - # make sure out resolution is divisible by bucket_tolerance - if resolution % bucket_tolerance != 0: - # reduce it to the nearest divisible number - resolution = resolution - (resolution % bucket_tolerance) + total_pixels = resolution * resolution # for file_item in enumerate(file_list): for idx, file_item in enumerate(file_list): width = file_item.crop_width height = file_item.crop_height - # determine new size, smallest dimension should be equal to resolution - # the other dimension should be the same ratio it is now (bigger) - new_width = resolution - new_height = resolution - if width > height: - # scale width to match new resolution, - new_width = int(width * (resolution / height)) - file_item.crop_width = new_width - file_item.scale_to_width = new_width - file_item.crop_height = resolution - file_item.scale_to_height = resolution - # make sure new_width is divisible by bucket_tolerance + # determine new resolution to have the same number of pixels + current_pixels = width * height + if current_pixels == total_pixels: + # no change + continue + + aspect_ratio = width / height + new_height = int(math.sqrt(total_pixels / aspect_ratio)) + new_width = int(aspect_ratio * new_height) + + # increase smallest one to be divisible by bucket_tolerance and increase the other to match + if new_width < new_height: + # increase width if new_width % bucket_tolerance != 0: - # reduce it to the nearest divisible number - reduction = new_width % bucket_tolerance - file_item.crop_width = new_width - reduction - new_width = file_item.crop_width - # adjust the new x position so we evenly 
crop - file_item.crop_x = int(file_item.crop_x + (reduction / 2)) - elif height > width: - # scale height to match new resolution - new_height = int(height * (resolution / width)) - file_item.crop_height = new_height - file_item.scale_to_height = new_height - file_item.scale_to_width = resolution - file_item.crop_width = resolution - # make sure new_height is divisible by bucket_tolerance - if new_height % bucket_tolerance != 0: - # reduce it to the nearest divisible number - reduction = new_height % bucket_tolerance - file_item.crop_height = new_height - reduction - new_height = file_item.crop_height - # adjust the new x position so we evenly crop - file_item.crop_y = int(file_item.crop_y + (reduction / 2)) + crop_amount = new_width % bucket_tolerance + new_width = new_width + (bucket_tolerance - crop_amount) + new_height = int(new_width / aspect_ratio) else: - # square image - file_item.crop_height = resolution - file_item.scale_to_height = resolution - file_item.scale_to_width = resolution - file_item.crop_width = resolution + # increase height + if new_height % bucket_tolerance != 0: + crop_amount = new_height % bucket_tolerance + new_height = new_height + (bucket_tolerance - crop_amount) + new_width = int(aspect_ratio * new_height) + + # Ensure that the total number of pixels remains the same. 
+ # assert new_width * new_height == total_pixels + + file_item.scale_to_width = new_width + file_item.scale_to_height = new_height + file_item.crop_width = new_width + file_item.crop_height = new_height + # make sure it is divisible by bucket_tolerance, decrease if not + if new_width % bucket_tolerance != 0: + crop_amount = new_width % bucket_tolerance + file_item.crop_width = new_width - crop_amount + else: + file_item.crop_width = new_width + + if new_height % bucket_tolerance != 0: + crop_amount = new_height % bucket_tolerance + file_item.crop_height = new_height - crop_amount + else: + file_item.crop_height = new_height # check if bucket exists, if not, create it bucket_key = f'{new_width}x{new_height}' diff --git a/toolkit/image_utils.py b/toolkit/image_utils.py index cb74ed4f..0625a030 100644 --- a/toolkit/image_utils.py +++ b/toolkit/image_utils.py @@ -1,10 +1,14 @@ # ref https://github.com/scardine/image_size/blob/master/get_image_size.py +import atexit import collections import json import os import io import struct +import cv2 +import numpy as np + FILE_UNKNOWN = "Sorry, don't know how to get size for this file." 
@@ -112,7 +116,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None): width = int(w) height = int(h) elif ((size >= 24) and data.startswith(b'\211PNG\r\n\032\n') - and (data[12:16] == b'IHDR')): + and (data[12:16] == b'IHDR')): # PNGs imgtype = PNG w, h = struct.unpack(">LL", data[16:24]) @@ -190,7 +194,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None): 9: (4, boChar + "l"), # SLONG 10: (8, boChar + "ll"), # SRATIONAL 11: (4, boChar + "f"), # FLOAT - 12: (8, boChar + "d") # DOUBLE + 12: (8, boChar + "d") # DOUBLE } ifdOffset = struct.unpack(boChar + "L", data[4:8])[0] try: @@ -206,7 +210,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None): input.seek(entryOffset) tag = input.read(2) tag = struct.unpack(boChar + "H", tag)[0] - if(tag == 256 or tag == 257): + if (tag == 256 or tag == 257): # if type indicates that value fits into 4 bytes, value # offset is not an offset but value itself type = input.read(2) @@ -229,7 +233,7 @@ def get_image_metadata_from_bytesio(input, size, file_path=None): except Exception as e: raise UnknownImageFormat(str(e)) elif size >= 2: - # see http://en.wikipedia.org/wiki/ICO_(file_format) + # see http://en.wikipedia.org/wiki/ICO_(file_format) imgtype = 'ICO' input.seek(0) reserved = input.read(2) @@ -350,13 +354,13 @@ def main(argv=None): prs.add_option('-v', '--verbose', dest='verbose', - action='store_true',) + action='store_true', ) prs.add_option('-q', '--quiet', dest='quiet', - action='store_true',) + action='store_true', ) prs.add_option('-t', '--test', dest='run_tests', - action='store_true',) + action='store_true', ) argv = list(argv) if argv is not None else sys.argv[1:] (opts, args) = prs.parse_args(args=argv) @@ -417,6 +421,36 @@ def main(argv=None): return EX_OK +is_window_shown = False + + +def show_img(img): + global is_window_shown + + img = np.clip(img, 0, 255).astype(np.uint8) + cv2.imshow('AI Toolkit', img[:, :, ::-1]) + k = cv2.waitKey(10) & 0xFF + if k == 27: # Esc 
key to stop + print('\nESC pressed, stopping') + raise KeyboardInterrupt + # show again to initialize the window if first + if not is_window_shown: + cv2.imshow('AI Toolkit', img[:, :, ::-1]) + k = cv2.waitKey(10) & 0xFF + if k == 27: # Esc key to stop + print('\nESC pressed, stopping') + raise KeyboardInterrupt + is_window_shown = True + + +def on_exit(): + if is_window_shown: + cv2.destroyAllWindows() + + +atexit.register(on_exit) + if __name__ == "__main__": import sys - sys.exit(main(argv=sys.argv[1:])) \ No newline at end of file + + sys.exit(main(argv=sys.argv[1:])) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 6d1650e9..cbfc48c2 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -711,7 +711,7 @@ class StableDiffusion: raise ValueError(f"Unknown weight name: {name}") - def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=True): + def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False): return inject_trigger_into_prompt( prompt, trigger=trigger,