From 58dffd43a8f2e0aaff494b4ead71fdd9f99867a8 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 15 Jul 2024 19:07:41 -0600
Subject: [PATCH] Added caching of image sizes so we don't compute them every time.

---
 jobs/process/BaseSDTrainProcess.py          | 18 ++++++++-----
 testing/merge_in_text_encoder_adapter.py    | 14 +++++------
 testing/test_bucket_dataloader.py           |  6 ++---
 toolkit/custom_adapter.py                   | 20 ++++++++++++++-
 toolkit/data_loader.py                      | 13 ++++++++++
 toolkit/data_transfer_object/data_loader.py | 25 +++++++++++-------
 toolkit/models/te_adapter.py                | 28 +++++++++++++++------
 7 files changed, 90 insertions(+), 34 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index b12f7e7c..68b6512f 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -714,6 +714,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
             sigma = sigma.unsqueeze(-1)
         return sigma
 
+    def get_noise(self, latents, batch_size, dtype=torch.float32):
+        # get noise
+        noise = self.sd.get_latent_noise(
+            height=latents.shape[2],
+            width=latents.shape[3],
+            batch_size=batch_size,
+            noise_offset=self.train_config.noise_offset,
+        ).to(self.device_torch, dtype=dtype)
+
+        return noise
+
     def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
         with torch.no_grad():
             with self.timer('prepare_prompt'):
@@ -927,12 +938,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 timesteps = torch.stack(timesteps, dim=0)
 
             # get noise
-            noise = self.sd.get_latent_noise(
-                height=latents.shape[2],
-                width=latents.shape[3],
-                batch_size=batch_size,
-                noise_offset=self.train_config.noise_offset,
-            ).to(self.device_torch, dtype=dtype)
+            noise = self.get_noise(latents, batch_size, dtype=dtype)
 
             # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
             # this will negate any noise offsets
diff --git a/testing/merge_in_text_encoder_adapter.py b/testing/merge_in_text_encoder_adapter.py
index 08d5c02e..d1a2983c 100644
--- a/testing/merge_in_text_encoder_adapter.py
+++ b/testing/merge_in_text_encoder_adapter.py
@@ -11,10 +11,10 @@ import json
 # te_path = "google/flan-t5-xl"
 # te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
 # output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
-model_path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
-te_path = "google/flan-t5-large"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5l_000034000.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw"
+model_path = "/home/jaret/Dev/models/hf/objective-reality-16ch"
+te_path = "google/flan-t5-xl"
+te_aug_path = "/mnt/Train2/out/ip_adapter/t5xl-sd15-16ch_v1/t5xl-sd15-16ch_v1_000115000.safetensors"
+output_path = "/home/jaret/Dev/models/hf/t5xl-sd15-16ch_sd15_v1"
 
 print("Loading te adapter")
 
@@ -28,13 +28,13 @@ is_pixart = "pixart" in model_path.lower()
 
 pipeline_class = StableDiffusionPipeline
 
-transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
+# transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
 
 if is_pixart:
     pipeline_class = PixArtSigmaPipeline
 
 if is_diffusers:
-    sd = pipeline_class.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
+    sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
 else:
     sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
 
@@ -50,7 +50,7 @@ if is_pixart:
     unet = sd.transformer
     unet_sd = sd.transformer.state_dict()
 else:
-    unet = sd.transformer
+    unet = sd.unet
     unet_sd = sd.unet.state_dict()
 
 
diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py
index 9d56b019..a430d707 100644
--- a/testing/test_bucket_dataloader.py
+++ b/testing/test_bucket_dataloader.py
@@ -187,7 +187,7 @@ for epoch in range(args.epochs):
         batch: 'DataLoaderBatchDTO'
         img_batch = batch.tensor
 
-        img_batch = color_block_imgs(img_batch, neg1_1=True)
+        # img_batch = color_block_imgs(img_batch, neg1_1=True)
 
         chunks = torch.chunk(img_batch, batch_size, dim=0)
         # put them so they are size by side
@@ -208,9 +208,9 @@ for epoch in range(args.epochs):
         # convert to image
         img = transforms.ToPILImage()(big_img)
 
-        show_img(img)
+        # show_img(img)
 
-        time.sleep(1.0)
+        # time.sleep(1.0)
     # if not last epoch
     if epoch < args.epochs - 1:
         trigger_dataloader_setup_epoch(dataloader)
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 44750c9b..27120572 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -39,7 +39,8 @@ from transformers import (
     AutoImageProcessor,
     ConvNextModel,
     ConvNextForImageClassification,
-    ConvNextImageProcessor
+    ConvNextImageProcessor,
+    UMT5EncoderModel, LlamaTokenizerFast
 )
 
 from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
@@ -165,6 +166,23 @@ class CustomAdapter(torch.nn.Module):
 
             # self.te.to = lambda *args, **kwargs: None
             self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
+        elif self.config.text_encoder_arch == 'pile-t5':
+            te_kwargs = {}
+            # te_kwargs['load_in_4bit'] = True
+            # te_kwargs['load_in_8bit'] = True
+            te_kwargs['device_map'] = "auto"
+            te_is_quantized = True
+
+            self.te = UMT5EncoderModel.from_pretrained(
+                self.config.text_encoder_path,
+                torch_dtype=torch_dtype,
+                **te_kwargs
+            )
+
+            # self.te.to = lambda *args, **kwargs: None
+            self.tokenizer = LlamaTokenizerFast.from_pretrained(self.config.text_encoder_path)
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
 
         elif self.config.text_encoder_arch == 'clip':
             self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device, dtype=torch_dtype)
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index a73fb001..e786e8e0 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -433,7 +433,15 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
         ])
 
         # this might take a while
+        print(f"Dataset: {self.dataset_path}")
         print(f" - Preprocessing image dimensions")
+        dataset_size_file = os.path.join(self.dataset_path, '.aitk_size.json')
+        if os.path.exists(dataset_size_file):
+            with open(dataset_size_file, 'r') as f:
+                self.size_database = json.load(f)
+        else:
+            self.size_database = {}
+
         bad_count = 0
         for file in tqdm(file_list):
             try:
@@ -442,6 +450,7 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                     path=file,
                     dataset_config=dataset_config,
                     dataloader_transforms=self.transform,
+                    size_database=self.size_database,
                 )
                 self.file_list.append(file_item)
             except Exception as e:
@@ -450,6 +459,10 @@ class AiToolkitDataset(LatentCachingMixin, CLIPCachingMixin, BucketsMixin, Capti
                 print(e)
                 bad_count += 1
 
+        # save the size database
+        with open(dataset_size_file, 'w') as f:
+            json.dump(self.size_database, f)
+
{len(self.file_list)} images") # print(f" - Found {bad_count} images that are too small") assert len(self.file_list) > 0, f"no images found in {self.dataset_path}" diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index b94ddddc..e936fb21 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -1,3 +1,4 @@ +import os import weakref from _weakref import ReferenceType from typing import TYPE_CHECKING, List, Union @@ -40,16 +41,22 @@ class FileItemDTO( ArgBreakMixin, ): def __init__(self, *args, **kwargs): - self.path = kwargs.get('path', None) + self.path = kwargs.get('path', '') self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) - # process width and height - try: - w, h = image_utils.get_image_size(self.path) - except image_utils.UnknownImageFormat: - print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \ - f'This process is faster for png, jpeg') - img = exif_transpose(Image.open(self.path)) - h, w = img.size + size_database = kwargs.get('size_database', {}) + filename = os.path.basename(self.path) + if filename in size_database: + w, h = size_database[filename] + else: + # process width and height + try: + w, h = image_utils.get_image_size(self.path) + except image_utils.UnknownImageFormat: + print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \ + f'This process is faster for png, jpeg') + img = exif_transpose(Image.open(self.path)) + h, w = img.size + size_database[filename] = (w, h) self.width: int = w self.height: int = h self.dataloader_transforms = kwargs.get('dataloader_transforms', None) diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py index bc182b70..267a0dde 100644 --- a/toolkit/models/te_adapter.py +++ b/toolkit/models/te_adapter.py @@ -243,7 +243,7 @@ class TEAdapter(torch.nn.Module): self.embeds_store = [] is_pixart = sd.is_pixart - if self.adapter_ref().config.text_encoder_arch == "t5": + if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5": self.token_size = self.te_ref().config.d_model else: self.token_size = self.te_ref().config.target_hidden_size @@ -388,13 +388,25 @@ class TEAdapter(torch.nn.Module): # ).input_ids.to(te.device) # outputs = te(input_ids=input_ids) # outputs = outputs.last_hidden_state - embeds, attention_mask = train_tools.encode_prompts_pixart( - tokenizer, - te, - text, - truncate=True, - max_length=self.adapter_ref().config.num_tokens, - ) + + if self.adapter_ref().config.text_encoder_arch == "pile-t5": + # just use aura pile + embeds, attention_mask = train_tools.encode_prompts_auraflow( + tokenizer, + te, + text, + truncate=True, + max_length=self.adapter_ref().config.num_tokens, + ) + + else: + embeds, attention_mask = train_tools.encode_prompts_pixart( + tokenizer, + te, + text, + truncate=True, + max_length=self.adapter_ref().config.num_tokens, + ) attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype) if self.text_projection is not None: # pool the output of embeds ignoring 0 in the attention mask