From 49c41e6a5fc4106a50d0e4fe02b3cb21a8c55de9 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 21 Feb 2024 04:51:52 -0700
Subject: [PATCH] Bug fixes. allow for random negative prompts

---
 extensions_built_in/sd_trainer/SDTrainer.py | 31 +++++++++++++++-
 scripts/make_diffusers_model.py             |  4 ++
 scripts/patch_te_adapter.py                 | 42 +++++++++++++
 scripts/repair_dataset_folder.py            | 65 +++++++++++++++++++++
 toolkit/config_modules.py                   |  1 +
 toolkit/ip_adapter.py                       | 17 +++++-
 toolkit/models/te_adapter.py                | 10 +++-
 toolkit/stable_diffusion_model.py           |  2 +-
 8 files changed, 166 insertions(+), 6 deletions(-)
 create mode 100644 scripts/patch_te_adapter.py
 create mode 100644 scripts/repair_dataset_folder.py

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index ba74286a..9e713475 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1,3 +1,4 @@
+import os
 import random
 from collections import OrderedDict
 from typing import Union, Literal, List, Optional
@@ -51,6 +52,8 @@ class SDTrainer(BaseSDTrainProcess):
         self.taesd: Optional[AutoencoderTiny] = None
 
         self._clip_image_embeds_unconditional: Union[List[str], None] = None
+        self.negative_prompt_pool: Union[List[str], None] = None
+        self.batch_negative_prompt: Union[List[str], None] = None
 
     def before_model_load(self):
         pass
@@ -108,6 +111,16 @@ class SDTrainer(BaseSDTrainProcess):
 
             self._clip_image_embeds_unconditional = unconditional_clip_image_embeds
 
+        if self.train_config.negative_prompt is not None:
+            if os.path.exists(self.train_config.negative_prompt):
+                with open(self.train_config.negative_prompt, 'r') as f:
+                    self.negative_prompt_pool = f.readlines()
+                    # remove empty
+                    self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""]
+            else:
+                # single prompt
+                self.negative_prompt_pool = [self.train_config.negative_prompt]
+
     def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
         # to process turbo learning, we make one big step from our current timestep to the end
         # we then denoise the prediction on that remaining step and target our loss to our target latents
@@ -781,6 +794,18 @@ class SDTrainer(BaseSDTrainProcess):
         batch = self.preprocess_batch(batch)
         dtype = get_torch_dtype(self.train_config.dtype)
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
+        if self.train_config.do_cfg or self.train_config.do_random_cfg:
+            # pick random negative prompts
+            if self.negative_prompt_pool is not None:
+                negative_prompts = []
+                for i in range(noisy_latents.shape[0]):
+                    num_neg = random.randint(1, self.train_config.max_negative_prompts)
+                    this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
+                    this_neg_prompt = ', '.join(this_neg_prompts)
+                    negative_prompts.append(this_neg_prompt)
+                self.batch_negative_prompt = negative_prompts
+            else:
+                self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]
 
         if self.adapter and isinstance(self.adapter, CustomAdapter):
             # condition the prompt
@@ -1030,7 +1055,8 @@ class SDTrainer(BaseSDTrainProcess):
                 if self.train_config.do_cfg:
                     # todo only do one and repeat it
                     unconditional_embeds = self.sd.encode_prompt(
-                        ["" for _ in range(noisy_latents.shape[0])],
+                        self.batch_negative_prompt,
+                        self.batch_negative_prompt,
                         dropout_prob=self.train_config.prompt_dropout_prob,
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
@@ -1050,9 +1076,8 @@ class SDTrainer(BaseSDTrainProcess):
                         self.device_torch,
                         dtype=dtype)
                 if self.train_config.do_cfg:
-                    # todo only do one and repeat it
                     unconditional_embeds = self.sd.encode_prompt(
-                        ["" for _ in range(noisy_latents.shape[0])],
+                        self.batch_negative_prompt,
                         dropout_prob=self.train_config.prompt_dropout_prob,
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
diff --git a/scripts/make_diffusers_model.py b/scripts/make_diffusers_model.py
index 1ec6c93a..4536a921 100644
--- a/scripts/make_diffusers_model.py
+++ b/scripts/make_diffusers_model.py
@@ -1,5 +1,9 @@
 import argparse
 from collections import OrderedDict
+import sys
+import os
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_DIR)
 
 import torch
 
diff --git a/scripts/patch_te_adapter.py b/scripts/patch_te_adapter.py
new file mode 100644
index 00000000..7249a46d
--- /dev/null
+++ b/scripts/patch_te_adapter.py
@@ -0,0 +1,42 @@
+import torch
+from safetensors.torch import save_file, load_file
+from collections import OrderedDict
+meta = OrderedDict()
+meta["format"] = "pt"
+
+attn_dict = load_file("/mnt/Train/out/ip_adapter/sd15_bigG/sd15_bigG_000266000.safetensors")
+state_dict = load_file("/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors")
+
+attn_list = []
+for key, value in state_dict.items():
+    if "attn1" in key:
+        attn_list.append(key)
+
+attn_names = ['down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor']
+
+adapter_names = []
+for i in range(100):
+    if f'te_adapter.adapter_modules.{i}.to_k_adapter.weight' in attn_dict:
+        adapter_names.append(f"te_adapter.adapter_modules.{i}.adapter")
+
+
+for i in range(len(adapter_names)):
+    adapter_name = adapter_names[i]
+    attn_name = attn_names[i]
+    adapter_k_name = adapter_name[:-8] + '.to_k_adapter.weight'
+    adapter_v_name = adapter_name[:-8] + '.to_v_adapter.weight'
+    state_k_name = attn_name.replace(".processor", ".to_k.weight")
+    state_v_name = attn_name.replace(".processor", ".to_v.weight")
+    if adapter_k_name in attn_dict:
+        state_dict[state_k_name] = attn_dict[adapter_k_name]
+        state_dict[state_v_name] = attn_dict[adapter_v_name]
+    else:
+        print("adapter_k_name", adapter_k_name)
+        print("state_k_name", state_k_name)
+
+for key, value in state_dict.items():
+    state_dict[key] = value.cpu().to(torch.float16)
+
+save_file(state_dict, "/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors", metadata=meta)
+
+print("Done")
diff --git a/scripts/repair_dataset_folder.py b/scripts/repair_dataset_folder.py
new file mode 100644
index 00000000..ad9d2775
--- /dev/null
+++ b/scripts/repair_dataset_folder.py
@@ -0,0 +1,65 @@
+import argparse
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from tqdm import tqdm
+import os
+
+parser = argparse.ArgumentParser(description='Process some images.')
+parser.add_argument("input_folder", type=str, help="Path to folder containing images")
+
+args = parser.parse_args()
+
+img_types = ['.jpg', '.jpeg', '.png', '.webp']
+
+# find all images in the input folder
+images = []
+for root, _, files in os.walk(args.input_folder):
+    for file in files:
+        if file.lower().endswith(tuple(img_types)):
+            images.append(os.path.join(root, file))
+print(f"Found {len(images)} images")
+
+num_skipped = 0
+num_repaired = 0
+num_deleted = 0
+
+pbar = tqdm(total=len(images), desc=f"Repaired {num_repaired} images", unit="image")
+for img_path in images:
+    filename = os.path.basename(img_path)
+    filename_no_ext, file_extension = os.path.splitext(filename)
+    # if it is jpg, ignore
+    if file_extension.lower() == '.jpg':
+        num_skipped += 1
+        pbar.update(1)
+
+        continue
+
+    try:
+        img = Image.open(img_path)
+    except Exception as e:
+        print(f"Error opening {img_path}: {e}")
+        # delete it
+        os.remove(img_path)
+        num_deleted += 1
+        pbar.update(1)
+        pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}")
+        continue
+
+
+    try:
+        img = exif_transpose(img)
+    except Exception as e:
+        print(f"Error rotating {img_path}: {e}")
+
+    new_path = os.path.join(os.path.dirname(img_path), filename_no_ext + '.jpg')
+
+    img = img.convert("RGB")
+    img.save(new_path, quality=95)
+    # remove the old file
+    os.remove(img_path)
+    num_repaired += 1
+    pbar.update(1)
+    # update pbar
+    pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}")
+
+print("Done")
\ No newline at end of file
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 1768697d..dec6803d 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -240,6 +240,7 @@ class TrainConfig:
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
         self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
         self.negative_prompt = kwargs.get('negative_prompt', None)
+        self.max_negative_prompts = kwargs.get('max_negative_prompts', 1)
         # multiplier applied to loos on regularization images
         self.reg_weight = kwargs.get('reg_weight', 1.0)
         self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index c3e2d33b..0ab619dd 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -322,7 +322,7 @@ class IPAdapter(torch.nn.Module):
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
+            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith('convnext') else \
                 self.image_encoder.config.hidden_sizes[-1]
 
             image_encoder_state_dict = self.image_encoder.state_dict()
@@ -340,6 +340,10 @@ class IPAdapter(torch.nn.Module):
                 dim = 4096
                 output_dim = 4096
 
+            if self.config.image_encoder_arch.startswith('convnext'):
+                in_tokens = 16 * 16
+                embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+
             # ip-adapter-plus
             image_proj_model = Resampler(
                 dim=dim,
@@ -406,6 +410,8 @@ class IPAdapter(torch.nn.Module):
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())
 
+        attn_processor_names = []
+
         for name in attn_processor_keys:
             cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
                 sd.unet.config['cross_attention_dim']
@@ -446,6 +452,9 @@ class IPAdapter(torch.nn.Module):
                 }
                 attn_procs[name].load_state_dict(weights)
+                attn_processor_names.append(name)
+        print(f"Attn Processors")
+        print(attn_processor_names)
 
         if self.sd_ref().is_pixart:
             # we have to set them ourselves
             transformer: Transformer2DModel = sd.unet
@@ -690,6 +699,12 @@ class IPAdapter(torch.nn.Module):
                 else:
                     clip_image_embeds = clip_output.image_embeds
 
+                if self.config.image_encoder_arch.startswith('convnext'):
+                    # flatten the width height layers to make the token space
+                    clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
+                    # rearrange to (batch, tokens, size)
+                    clip_image_embeds = clip_image_embeds.permute(0, 2, 1)
+
             if self.config.quad_image:
                 # get the outputs of the quat
                 chunks = clip_image_embeds.chunk(quad_count, dim=0)
diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py
index 10198642..0a1acada 100644
--- a/toolkit/models/te_adapter.py
+++ b/toolkit/models/te_adapter.py
@@ -171,11 +171,19 @@ class TEAdapter(torch.nn.Module):
         self.te_ref: weakref.ref = weakref.ref(te)
         self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
 
-        self.token_size = self.te_ref().config.d_model
+        if self.adapter_ref().config.text_encoder_arch == "t5":
+            self.token_size = self.te_ref().config.d_model
+        else:
+            self.token_size = self.te_ref().config.hidden_size
 
         # init adapter modules
         attn_procs = {}
         unet_sd = sd.unet.state_dict()
+        attn_dict_map = {
+
+        }
+        module_idx = 0
+        attn_processors_list = list(sd.unet.attn_processors.keys())
         for name in sd.unet.attn_processors.keys():
             cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 61139f8d..5694a146 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -287,7 +287,7 @@ class StableDiffusion:
                     load_safety_checker=False,
                     requires_safety_checker=False,
                     torch_dtype=self.torch_dtype,
-                    safety_checker=False,
+                    safety_checker=None,
                     **load_args
                 ).to(self.device_torch)
                 flush()
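
A minimal standalone sketch of the negative-prompt sampling this commit adds, assuming `train_config.negative_prompt` is either a path to a text file with one prompt per line or a literal prompt string, and that each sample in a batch draws between 1 and `max_negative_prompts` pool entries joined with commas. The helper names and the prompt-file path below are hypothetical illustrations; only the logic mirrors the additions to SDTrainer.py, and the sketch is not part of the diff.

import os
import random
from typing import List


def build_negative_prompt_pool(negative_prompt: str) -> List[str]:
    # A path loads one prompt per line (empty lines dropped); anything else becomes a single-prompt pool.
    if os.path.exists(negative_prompt):
        with open(negative_prompt, 'r') as f:
            return [line.strip() for line in f.readlines() if line.strip() != ""]
    return [negative_prompt]


def sample_batch_negative_prompts(pool: List[str], batch_size: int, max_negative_prompts: int = 1) -> List[str]:
    # Each sample draws 1..max_negative_prompts random pool entries and joins them with ", ".
    batch_prompts = []
    for _ in range(batch_size):
        num_neg = random.randint(1, max_negative_prompts)
        batch_prompts.append(', '.join(random.choice(pool) for _ in range(num_neg)))
    return batch_prompts


if __name__ == "__main__":
    # "negative_prompts.txt" is a hypothetical file; a missing path degrades to a single-prompt pool.
    pool = build_negative_prompt_pool("negative_prompts.txt")
    print(sample_batch_negative_prompts(pool, batch_size=4, max_negative_prompts=3))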