diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 9f5e1a5e..5ed8005f 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1450,8 +1450,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             flush()
 
         # self.step_num = 0
-        print(f"Compiling Model")
-        torch.compile(self.sd.unet, dynamic=True)
+        # print(f"Compiling Model")
+        # torch.compile(self.sd.unet, dynamic=True)
 
         ###################################################################
         # TRAIN LOOP
diff --git a/requirements.txt b/requirements.txt
index f19e9c6e..6caa30b9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,4 +23,5 @@ prodigyopt
 controlnet_aux==0.0.7
 python-dotenv
 bitsandbytes
-xformers
\ No newline at end of file
+xformers
+hf_transfer
\ No newline at end of file
diff --git a/run.py b/run.py
index 8c767ec5..6f133081 100644
--- a/run.py
+++ b/run.py
@@ -1,4 +1,5 @@
 import os
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 import sys
 from typing import Union, OrderedDict
 from dotenv import load_dotenv
diff --git a/testing/merge_in_text_encoder_adapter.py b/testing/merge_in_text_encoder_adapter.py
new file mode 100644
index 00000000..9158384a
--- /dev/null
+++ b/testing/merge_in_text_encoder_adapter.py
@@ -0,0 +1,105 @@
+import os
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+import json
+
+model_path = "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/objective_reality_v2.safetensors"
+te_path = "google/flan-t5-xl"
+te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
+output_path = "/home/jaret/Dev/models/hf/t5xl_sd15_v1"
+
+print("Loading te adapter")
+te_aug_sd = load_file(te_aug_path)
+
+print("Loading model")
+sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)
+
+print("Loading Text Encoder")
+# Load the text encoder
+te = T5EncoderModel.from_pretrained(te_path, torch_dtype=torch.float16)
+
+# patch it
+sd.text_encoder = te
+sd.tokenizer = T5Tokenizer.from_pretrained(te_path)
+
+unet_sd = sd.unet.state_dict()
+
+weight_idx = 1
+
+new_cross_attn_dim = None
+
+print("Patching UNet")
+for name in sd.unet.attn_processors.keys():
+    cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
+    if name.startswith("mid_block"):
+        hidden_size = sd.unet.config['block_out_channels'][-1]
+    elif name.startswith("up_blocks"):
+        block_id = int(name[len("up_blocks.")])
+        hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
+    elif name.startswith("down_blocks"):
+        block_id = int(name[len("down_blocks.")])
+        hidden_size = sd.unet.config['block_out_channels'][block_id]
+    else:
+        # they didnt have this, but would lead to undefined below
+        raise ValueError(f"unknown attn processor name: {name}")
+    if cross_attention_dim is None:
+        pass
+    else:
+        layer_name = name.split(".processor")[0]
+        to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
+        to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
+
+        te_aug_name = None
+        while True:
+            te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter"
+            if f"{te_aug_name}.weight" in te_aug_sd:
+                # increment so we dont redo it next time
+                weight_idx += 1
+                break
+            else:
+                weight_idx += 1
+
+            if weight_idx > 1000:
+                raise ValueError("Could not find the next weight")
+
+        unet_sd[layer_name + ".to_k.weight"] = te_aug_sd[te_aug_name + ".weight"]
+        unet_sd[layer_name + ".to_v.weight"] = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"]
+
+        if new_cross_attn_dim is None:
+            new_cross_attn_dim = unet_sd[layer_name + ".to_k.weight"].shape[1]
+
+
+print("Saving unmodified model")
+sd.save_pretrained(
+    output_path,
+    safe_serialization=True,
+)
+
+# overwrite the unet
+unet_folder = os.path.join(output_path, "unet")
+
+# move state_dict to cpu
+unet_sd = {k: v.clone().cpu().to(torch.float16) for k, v in unet_sd.items()}
+
+meta = OrderedDict()
+meta["format"] = "pt"
+
+print("Patching new unet")
+
+save_file(unet_sd, os.path.join(unet_folder, "diffusion_pytorch_model.safetensors"), meta)
+
+# load the json file
+with open(os.path.join(unet_folder, "config.json"), 'r') as f:
+    config = json.load(f)
+
+config['cross_attention_dim'] = new_cross_attn_dim
+
+# save it
+with open(os.path.join(unet_folder, "config.json"), 'w') as f:
+    json.dump(config, f, indent=2)
+
+print("Done")
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 132906a5..d5ab38b6 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -344,6 +344,7 @@ class ModelConfig:
         self._original_refiner_name_or_path = self.refiner_name_or_path
         self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
         self.lora_path = kwargs.get('lora_path', None)
+        self.latent_space_version = kwargs.get('latent_space_version', None)
 
         # only for SDXL models for now
         self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index c27483c0..f59eb124 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -111,6 +111,7 @@ class CustomAdapter(torch.nn.Module):
             self.load_state_dict(loaded_state_dict, strict=False)
 
     def setup_adapter(self):
+        torch_dtype = get_torch_dtype(self.sd_ref().dtype)
         if self.adapter_type == 'photo_maker':
             sd = self.sd_ref()
             embed_dim = sd.unet.config['cross_attention_dim']
@@ -146,14 +147,23 @@
             )
         elif self.adapter_type == 'text_encoder':
             if self.config.text_encoder_arch == 't5':
-                self.te = T5EncoderModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
-                                                                                           dtype=get_torch_dtype(
-                                                                                               self.sd_ref().dtype))
+                te_kwargs = {}
+                # te_kwargs['load_in_4bit'] = True
+                # te_kwargs['load_in_8bit'] = True
+                te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+
+                self.te = T5EncoderModel.from_pretrained(
+                    self.config.text_encoder_path,
+                    torch_dtype=torch_dtype,
+                    **te_kwargs
+                )
+
+                # self.te.to = lambda *args, **kwargs: None
                 self.tokenizer = T5Tokenizer.from_pretrained(self.config.text_encoder_path)
             elif self.config.text_encoder_arch == 'clip':
                 self.te = CLIPTextModel.from_pretrained(self.config.text_encoder_path).to(self.sd_ref().unet.device,
-                                                                                          dtype=get_torch_dtype(
-                                                                                              self.sd_ref().dtype))
+                                                                                          dtype=torch_dtype)
                 self.tokenizer = CLIPTokenizer.from_pretrained(self.config.text_encoder_path)
             else:
                 raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
@@ -531,7 +541,14 @@
             has_been_preprocessed=False,
             is_unconditional=False,
             quad_count=4,
+            is_generating_samples=False,
     ) -> PromptEmbeds:
+        if self.adapter_type == 'text_encoder' and is_generating_samples:
+            # replace the prompt embed with ours
+            if is_unconditional:
+                return self.unconditional_embeds.clone()
+            return self.conditional_embeds.clone()
+
         if self.adapter_type == 'ilora':
             return prompt_embeds
 
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 4755dd88..28e1c5ce 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -585,7 +585,7 @@ def get_dataloader_from_datasets(
             drop_last=False,
             shuffle=True,
             collate_fn=dto_collation,  # Use the custom collate function
-            num_workers=4
+            num_workers=8
         )
     else:
         data_loader = DataLoader(
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index c164a3e0..ccaf6bc9 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -361,7 +361,7 @@ class CaptionProcessingDTOMixin:
                 caption = ', '.join(token_list)
 
         caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
-        if self.dataset_config.random_triggers and len(self.dataset_config.random_triggers) > 0:
+        if self.dataset_config.random_triggers:
            num_triggers = self.dataset_config.random_triggers_max
            if num_triggers > 1:
                num_triggers = random.randint(0, num_triggers)
@@ -369,6 +369,9 @@
            if num_triggers > 0:
                # add random triggers
                for i in range(num_triggers):
+
+
+
                    caption = caption + ', ' + random.choice(self.dataset_config.random_triggers)
 
        if self.dataset_config.shuffle_tokens:
@@ -1316,7 +1319,9 @@ class LatentCachingMixin:
        i = 0
        for file_item in tqdm(self.file_list, desc=f'Caching latents{" to disk" if to_disk else ""}'):
            # set latent space version
-            if self.sd.is_xl:
+            if self.sd.model_config.latent_space_version is not None:
+                file_item.latent_space_version = self.sd.model_config.latent_space_version
+            elif self.sd.is_xl:
                file_item.latent_space_version = 'sdxl'
            else:
                file_item.latent_space_version = 'sd1'
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 928f8ebd..1fdbcb20 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -487,7 +487,7 @@ class IPAdapter(torch.nn.Module):
         attn_processor_names = []
 
         for name in attn_processor_keys:
-            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
+            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
                 sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
                 hidden_size = sd.unet.config['block_out_channels'][-1]
@@ -540,9 +540,6 @@
                     module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
             self.adapter_modules = torch.nn.ModuleList(
                 [
-                    transformer.transformer_blocks[i].attn1.processor for i in
-                    range(len(transformer.transformer_blocks))
-                ] + [
                     transformer.transformer_blocks[i].attn2.processor for i in
                     range(len(transformer.transformer_blocks))
                 ])
diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py
index c1b77831..a8676ec3 100644
--- a/toolkit/models/te_adapter.py
+++ b/toolkit/models/te_adapter.py
@@ -6,8 +6,12 @@
 import torch.nn.functional as F
 import weakref
 from typing import Union, TYPE_CHECKING
 
-from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
+from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
+
+from toolkit import train_tools
 from toolkit.paths import REPOS_ROOT
+from toolkit.prompt_utils import PromptEmbeds
+
 sys.path.append(REPOS_ROOT)
 from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
@@ -106,14 +110,14 @@ class TEAdapterAttnProcessor(nn.Module):
         # only use one TE or the other. If our adapter is active only use ours
         if self.is_active and self.conditional_embeds is not None:
-            adapter_hidden_states = self.conditional_embeds
+            adapter_hidden_states = self.conditional_embeds.text_embeds
             # check if we are doing unconditional
             if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != encoder_hidden_states.shape[0]:
                 # concat unconditional to match the hidden state batch size
-                if self.unconditional_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
-                    unconditional = torch.cat([self.unconditional_embeds] * adapter_hidden_states.shape[0], dim=0)
+                if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
+                    unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
                 else:
-                    unconditional = self.unconditional_embeds
+                    unconditional = self.unconditional_embeds.text_embeds
                 adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
 
             # for ip-adapter
             key = self.to_k_adapter(adapter_hidden_states)
@@ -163,7 +167,7 @@
             self,
             adapter: 'CustomAdapter',
             sd: 'StableDiffusion',
-            te: Union[T5EncoderModel, CLIPTextModel],
+            te: Union[T5EncoderModel],
             tokenizer: CLIPTokenizer
     ):
         super(TEAdapter, self).__init__()
@@ -178,6 +182,12 @@
         else:
             self.token_size = self.te_ref().config.hidden_size
 
+        # add text projection if is sdxl
+        self.text_projection = None
+        if sd.is_xl:
+            clip_with_projection: CLIPTextModelWithProjection = sd.text_encoder[0]
+            self.text_projection = nn.Linear(te.config.hidden_size, clip_with_projection.config.projection_dim, bias=False)
+
         # init adapter modules
         attn_procs = {}
         unet_sd = sd.unet.state_dict()
@@ -258,16 +268,48 @@
         te: T5EncoderModel = self.te_ref()
         tokenizer: T5Tokenizer = self.tokenizer_ref()
-        input_ids = tokenizer(
+        # input_ids = tokenizer(
+        #     text,
+        #     max_length=77,
+        #     padding="max_length",
+        #     truncation=True,
+        #     return_tensors="pt",
+        # ).input_ids.to(te.device)
+        # outputs = te(input_ids=input_ids)
+        # outputs = outputs.last_hidden_state
+        embeds, attention_mask = train_tools.encode_prompts_pixart(
+            tokenizer,
+            te,
             text,
-            max_length=77,
-            padding="max_length",
-            truncation=True,
-            return_tensors="pt",
-        ).input_ids.to(te.device)
-        outputs = te(input_ids=input_ids)
-        outputs = outputs.last_hidden_state
-        return outputs
+            truncate=True,
+            max_length=self.adapter_ref().config.num_tokens,
+        )
+        attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
+        if self.text_projection is not None:
+            # pool the output of embeds ignoring 0 in the attention mask
+            pooled_output = embeds * attn_mask_float.unsqueeze(-1)
+
+            # reduce along dim 1 while maintaining batch and dim 2
+            pooled_output_sum = pooled_output.sum(dim=1)
+            attn_mask_sum = attn_mask_float.sum(dim=1).unsqueeze(-1)
+
+            pooled_output = pooled_output_sum / attn_mask_sum
+
+            pooled_embeds = self.text_projection(pooled_output)
+
+            t5_embeds = PromptEmbeds(
+                (embeds, pooled_embeds),
+                attention_mask=attention_mask,
+            ).detach()
+
+        else:
+
+            t5_embeds = PromptEmbeds(
+                embeds,
+                attention_mask=attention_mask,
+            ).detach()
+
+        return t5_embeds
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 44516834..b99569e7 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -672,6 +672,7 @@ class StableDiffusion:
                         prompt_embeds=conditional_embeds,
                         is_training=False,
                         has_been_preprocessed=False,
+                        is_generating_samples=True,
                     )
                     unconditional_embeds = self.adapter.condition_encoded_embeds(
                         tensors_0_1=validation_image,
@@ -679,6 +680,7 @@
                         is_training=False,
                         has_been_preprocessed=False,
                         is_unconditional=True,
+                        is_generating_samples=True,
                     )
 
                 if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
@@ -1324,6 +1326,20 @@
                     attention_mask=attention_mask,
                 )
+            elif isinstance(self.text_encoder, T5EncoderModel):
+                embeds, attention_mask = train_tools.encode_prompts_pixart(
+                    self.tokenizer,
+                    self.text_encoder,
+                    prompt,
+                    truncate=not long_prompts,
+                    max_length=77,  # todo set this higher when not transfer learning
+                    dropout_prob=dropout_prob
+                )
+                return PromptEmbeds(
+                    embeds,
+                    # do we want attn mask here?
+                    # attention_mask=attention_mask,
+                )
             else:
                 return PromptEmbeds(
                     train_tools.encode_prompts(
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 629d9ec5..e80c0649 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -665,7 +665,9 @@
     prompt_attention_mask = text_inputs.attention_mask
    prompt_attention_mask = prompt_attention_mask.to(text_encoder.device)
 
-    prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), attention_mask=prompt_attention_mask)
+    text_input_ids = text_input_ids.to(text_encoder.device)
+
+    prompt_embeds = text_encoder(text_input_ids, attention_mask=prompt_attention_mask)
 
     return prompt_embeds.last_hidden_state, prompt_attention_mask
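
Note on the pooled-embedding path added to TEAdapter.encode_text in toolkit/models/te_adapter.py: when the base model is SDXL, the T5 hidden states are mean-pooled over the non-padding tokens (using the attention mask) before being fed through the new text_projection linear layer to produce pooled embeds. Below is a minimal standalone sketch of that masked mean pooling, assuming embeds of shape (batch, seq_len, hidden) and a 0/1 attention mask of shape (batch, seq_len); the function name and the clamp guard against an all-zero mask are illustrative additions, not part of the patch.

import torch

def masked_mean_pool(embeds: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # cast the mask to the embedding device/dtype so it can broadcast over the hidden dim
    mask = attention_mask.to(embeds.device, dtype=embeds.dtype)
    masked = embeds * mask.unsqueeze(-1)      # zero out padding positions
    summed = masked.sum(dim=1)                # (batch, hidden)
    counts = mask.sum(dim=1).unsqueeze(-1)    # (batch, 1), number of real tokens
    return summed / counts.clamp(min=1.0)     # clamp is a safeguard; the patch divides directly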