From cab8a1c7b8fc5969d6af58a8e9527703cf29a282 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 6 Jul 2024 13:00:21 -0600 Subject: [PATCH] WIP to add the caption_proj weight to pixart sigma TE adapter --- testing/merge_in_text_encoder_adapter.py | 90 ++++++-- testing/shrink_pixart.py | 76 +++++++ toolkit/lora_special.py | 2 + toolkit/models/LoRAFormer.py | 267 +++++++++++++++++++++++ toolkit/models/ilora.py | 8 +- toolkit/models/te_adapter.py | 69 ++++++ toolkit/network_mixins.py | 8 +- toolkit/stable_diffusion_model.py | 3 + 8 files changed, 500 insertions(+), 23 deletions(-) create mode 100644 testing/shrink_pixart.py create mode 100644 toolkit/models/LoRAFormer.py diff --git a/testing/merge_in_text_encoder_adapter.py b/testing/merge_in_text_encoder_adapter.py index 6903a6ad..9fda8a90 100644 --- a/testing/merge_in_text_encoder_adapter.py +++ b/testing/merge_in_text_encoder_adapter.py @@ -2,15 +2,20 @@ import os import torch from transformers import T5EncoderModel, T5Tokenizer -from diffusers import StableDiffusionPipeline, UNet2DConditionModel +from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel from safetensors.torch import load_file, save_file from collections import OrderedDict import json -model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000" -te_path = "google/flan-t5-xl" -te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors" -output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw" +# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000" +# te_path = "google/flan-t5-xl" +# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors" +# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw" +model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS" +te_path = "google/flan-t5-base" +te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors" +output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS_t5base_raw" + print("Loading te adapter") te_aug_sd = load_file(te_aug_path) @@ -18,10 +23,18 @@ te_aug_sd = load_file(te_aug_path) print("Loading model") is_diffusers = (not os.path.exists(model_path)) or os.path.isdir(model_path) +# if "pixart" in model_path.lower(): +is_pixart = "pixart" in model_path.lower() + +pipeline_class = StableDiffusionPipeline + +if is_pixart: + pipeline_class = PixArtSigmaPipeline + if is_diffusers: - sd = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) + sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16) else: - sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16) + sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16) print("Loading Text Encoder") # Load the text encoder @@ -31,23 +44,49 @@ te = T5EncoderModel.from_pretrained(te_path, torch_dtype=torch.float16) sd.text_encoder = te sd.tokenizer = T5Tokenizer.from_pretrained(te_path) -unet_sd = sd.unet.state_dict() +if is_pixart: + unet = sd.transformer + unet_sd = sd.transformer.state_dict() +else: + unet = sd.transformer + unet_sd = sd.unet.state_dict() -weight_idx = 1 + +if is_pixart: + weight_idx = 0 +else: + weight_idx = 1 new_cross_attn_dim = None -print("Patching UNet") -for name in sd.unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim'] +# count the num of params in state dict +start_params = sum([v.numel() for v in unet_sd.values()]) + +print("Building") +attn_processor_keys = [] +if is_pixart: + transformer: Transformer2DModel = unet + for i, module in transformer.transformer_blocks.named_children(): + attn_processor_keys.append(f"transformer_blocks.{i}.attn1") + # cross attention + attn_processor_keys.append(f"transformer_blocks.{i}.attn2") +else: + attn_processor_keys = list(unet.attn_processors.keys()) + +for name in attn_processor_keys: + cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith( + "attn1") else \ + unet.config['cross_attention_dim'] if name.startswith("mid_block"): - hidden_size = sd.unet.config['block_out_channels'][-1] + hidden_size = unet.config['block_out_channels'][-1] elif name.startswith("up_blocks"): block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + hidden_size = list(reversed(unet.config['block_out_channels']))[block_id] elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) - hidden_size = sd.unet.config['block_out_channels'][block_id] + hidden_size = unet.config['block_out_channels'][block_id] + elif name.startswith("transformer"): + hidden_size = unet.config['cross_attention_dim'] else: # they didnt have this, but would lead to undefined below raise ValueError(f"unknown attn processor name: {name}") @@ -60,7 +99,10 @@ for name in sd.unet.attn_processors.keys(): te_aug_name = None while True: - te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter" + if is_pixart: + te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter" + else: + te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter" if f"{te_aug_name}.weight" in te_aug_sd: # increment so we dont redo it next time weight_idx += 1 @@ -86,7 +128,10 @@ sd.save_pretrained( ) # overwrite the unet -unet_folder = os.path.join(output_path, "unet") +if is_pixart: + unet_folder = os.path.join(output_path, "transformer") +else: + unet_folder = os.path.join(output_path, "unet") # move state_dict to cpu unet_sd = {k: v.clone().cpu().to(torch.float16) for k, v in unet_sd.items()} @@ -94,7 +139,7 @@ unet_sd = {k: v.clone().cpu().to(torch.float16) for k, v in unet_sd.items()} meta = OrderedDict() meta["format"] = "pt" -print("Patching new unet") +print("Patching") save_file(unet_sd, os.path.join(unet_folder, "diffusion_pytorch_model.safetensors"), meta) @@ -104,8 +149,17 @@ with open(os.path.join(unet_folder, "config.json"), 'r') as f: config['cross_attention_dim'] = new_cross_attn_dim +if is_pixart: + config['caption_channels'] = te.config.d_model + # save it with open(os.path.join(unet_folder, "config.json"), 'w') as f: json.dump(config, f, indent=2) print("Done") + +new_params = sum([v.numel() for v in unet_sd.values()]) + +# print new and old params with , formatted +print(f"Old params: {start_params:,}") +print(f"New params: {new_params:,}") diff --git a/testing/shrink_pixart.py b/testing/shrink_pixart.py new file mode 100644 index 00000000..1cac2a53 --- /dev/null +++ b/testing/shrink_pixart.py @@ -0,0 +1,76 @@ +import os + +import torch +from transformers import T5EncoderModel, T5Tokenizer +from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel +from safetensors.torch import load_file, save_file +from collections import OrderedDict +import json + +# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000" +# te_path = "google/flan-t5-xl" +# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors" +# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw" +model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors" +output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_tiny.safetensors" +te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors" + +state_dict = load_file(model_path) + +meta = OrderedDict() +meta["format"] = "pt" + +# has 28 blocks +# keep block 0 and 27 + +new_state_dict = {} + +# move non blocks over +for key, value in state_dict.items(): + if not key.startswith("transformer_blocks."): + new_state_dict[key] = value + +block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight', + 'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight', + 'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight', + 'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight', + 'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight', + 'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight', + 'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight', + 'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight', + 'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight', + 'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight', + 'transformer_blocks.{idx}.scale_shift_table'] + +# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27 + +current_idx = 0 +for i in range(28): + if i not in [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]: + # todo merge in with previous block + for name in block_names: + try: + new_state_dict_key = name.format(idx=current_idx - 1) + old_state_dict_key = name.format(idx=i) + new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5) + except KeyError: + raise KeyError(f"KeyError: {name.format(idx=current_idx)}") + else: + for name in block_names: + new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)] + current_idx += 1 + + +# make sure they are all fp16 and on cpu +for key, value in new_state_dict.items(): + new_state_dict[key] = value.to(torch.float16).cpu() + +# save the new state dict +save_file(new_state_dict, output_path, metadata=meta) + +new_param_count = sum([v.numel() for v in new_state_dict.values()]) +old_param_count = sum([v.numel() for v in state_dict.values()]) + +# porint comma formatted +print(f"Old param count: {old_param_count:,}") +print(f"New param count: {new_param_count:,}") \ No newline at end of file diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index bcb2d63d..a71a71b7 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -1,6 +1,7 @@ import copy import json import math +import weakref import os import re import sys @@ -59,6 +60,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): ToolkitModuleMixin.__init__(self, network=network) torch.nn.Module.__init__(self) self.lora_name = lora_name + self.orig_module_ref = weakref.ref(org_module) self.scalar = torch.tensor(1.0) # check if parent has bias. if not force use_bias to False if org_module.bias is None: diff --git a/toolkit/models/LoRAFormer.py b/toolkit/models/LoRAFormer.py new file mode 100644 index 00000000..78bb460d --- /dev/null +++ b/toolkit/models/LoRAFormer.py @@ -0,0 +1,267 @@ +import math +import weakref + +import torch +import torch.nn as nn +from typing import TYPE_CHECKING, List, Dict, Any +from toolkit.models.clip_fusion import ZipperBlock +from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler +import sys +from toolkit.paths import REPOS_ROOT +sys.path.append(REPOS_ROOT) +from ipadapter.ip_adapter.resampler import Resampler +from collections import OrderedDict + +if TYPE_CHECKING: + from toolkit.lora_special import LoRAModule + from toolkit.stable_diffusion_model import StableDiffusion + + +class TransformerBlock(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + nn.ReLU(), + nn.Linear(dim_feedforward, d_model) + ) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + + def forward(self, x, cross_attn_input): + # Self-attention + attn_output, _ = self.self_attn(x, x, x) + x = self.norm1(x + attn_output) + + # Cross-attention + cross_attn_output, _ = self.cross_attn(x, cross_attn_input, cross_attn_input) + x = self.norm2(x + cross_attn_output) + + # Feed-forward + ff_output = self.feed_forward(x) + x = self.norm3(x + ff_output) + + return x + + +class InstantLoRAMidModule(torch.nn.Module): + def __init__( + self, + index: int, + lora_module: 'LoRAModule', + instant_lora_module: 'InstantLoRAModule', + up_shape: list = None, + down_shape: list = None, + ): + super(InstantLoRAMidModule, self).__init__() + self.up_shape = up_shape + self.down_shape = down_shape + self.index = index + self.lora_module_ref = weakref.ref(lora_module) + self.instant_lora_module_ref = weakref.ref(instant_lora_module) + + self.embed = None + + def down_forward(self, x, *args, **kwargs): + # get the embed + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + down_size = math.prod(self.down_shape) + down_weight = self.embed[:, :down_size] + + batch_size = x.shape[0] + + # unconditional + if down_weight.shape[0] * 2 == batch_size: + down_weight = torch.cat([down_weight] * 2, dim=0) + + weight_chunks = torch.chunk(down_weight, batch_size, dim=0) + x_chunks = torch.chunk(x, batch_size, dim=0) + + x_out = [] + for i in range(batch_size): + weight_chunk = weight_chunks[i] + x_chunk = x_chunks[i] + # reshape + weight_chunk = weight_chunk.view(self.down_shape) + # check if is conv or linear + if len(weight_chunk.shape) == 4: + padding = 0 + if weight_chunk.shape[-1] == 3: + padding = 1 + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding) + else: + # run a simple linear layer with the down weight + x_chunk = x_chunk @ weight_chunk.T + x_out.append(x_chunk) + x = torch.cat(x_out, dim=0) + return x + + + def up_forward(self, x, *args, **kwargs): + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + up_size = math.prod(self.up_shape) + up_weight = self.embed[:, -up_size:] + + batch_size = x.shape[0] + + # unconditional + if up_weight.shape[0] * 2 == batch_size: + up_weight = torch.cat([up_weight] * 2, dim=0) + + weight_chunks = torch.chunk(up_weight, batch_size, dim=0) + x_chunks = torch.chunk(x, batch_size, dim=0) + + x_out = [] + for i in range(batch_size): + weight_chunk = weight_chunks[i] + x_chunk = x_chunks[i] + # reshape + weight_chunk = weight_chunk.view(self.up_shape) + # check if is conv or linear + if len(weight_chunk.shape) == 4: + padding = 0 + if weight_chunk.shape[-1] == 3: + padding = 1 + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding) + else: + # run a simple linear layer with the down weight + x_chunk = x_chunk @ weight_chunk.T + x_out.append(x_chunk) + x = torch.cat(x_out, dim=0) + return x + + +# Initialize the network +# num_blocks = 8 +# d_model = 1024 # Adjust as needed +# nhead = 16 # Adjust as needed +# dim_feedforward = 4096 # Adjust as needed +# latent_dim = 1695744 + +class LoRAFormer(torch.nn.Module): + def __init__( + self, + num_blocks, + d_model=1024, + nhead=16, + dim_feedforward=4096, + sd: 'StableDiffusion'=None, + ): + super(LoRAFormer, self).__init__() + # self.linear = torch.nn.Linear(2, 1) + self.sd_ref = weakref.ref(sd) + self.dim = sd.network.lora_dim + + # stores the projection vector. Grabbed by modules + self.img_embeds: List[torch.Tensor] = None + + # disable merging in. It is slower on inference + self.sd_ref().network.can_merge_in = False + + self.ilora_modules = torch.nn.ModuleList() + + lora_modules = self.sd_ref().network.get_all_modules() + + output_size = 0 + + self.embed_lengths = [] + self.weight_mapping = [] + + for idx, lora_module in enumerate(lora_modules): + module_dict = lora_module.state_dict() + down_shape = list(module_dict['lora_down.weight'].shape) + up_shape = list(module_dict['lora_up.weight'].shape) + + self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]]) + + module_size = math.prod(down_shape) + math.prod(up_shape) + output_size += module_size + self.embed_lengths.append(module_size) + + + # add a new mid module that will take the original forward and add a vector to it + # this will be used to add the vector to the original forward + instant_module = InstantLoRAMidModule( + idx, + lora_module, + self, + up_shape=up_shape, + down_shape=down_shape + ) + + self.ilora_modules.append(instant_module) + + # replace the LoRA forwards + lora_module.lora_down.forward = instant_module.down_forward + lora_module.lora_up.forward = instant_module.up_forward + + + self.output_size = output_size + + self.latent = nn.Parameter(torch.randn(1, output_size)) + self.latent_proj = nn.Linear(output_size, d_model) + self.blocks = nn.ModuleList([ + TransformerBlock(d_model, nhead, dim_feedforward) + for _ in range(num_blocks) + ]) + self.final_proj = nn.Linear(d_model, output_size) + + self.migrate_weight_mapping() + + def migrate_weight_mapping(self): + return + # # changes the names of the modules to common ones + # keymap = self.sd_ref().network.get_keymap() + # save_keymap = {} + # if keymap is not None: + # for ldm_key, diffusers_key in keymap.items(): + # # invert them + # save_keymap[diffusers_key] = ldm_key + # + # new_keymap = {} + # for key, value in self.weight_mapping: + # if key in save_keymap: + # new_keymap[save_keymap[key]] = value + # else: + # print(f"Key {key} not found in keymap") + # new_keymap[key] = value + # self.weight_mapping = new_keymap + # else: + # print("No keymap found. Using default names") + # return + + + def forward(self, img_embeds): + # expand token rank if only rank 2 + if len(img_embeds.shape) == 2: + img_embeds = img_embeds.unsqueeze(1) + + # resample the image embeddings + img_embeds = self.resampler(img_embeds) + img_embeds = self.proj_module(img_embeds) + if len(img_embeds.shape) == 3: + # merge the heads + img_embeds = img_embeds.mean(dim=1) + + self.img_embeds = [] + # get all the slices + start = 0 + for length in self.embed_lengths: + self.img_embeds.append(img_embeds[:, start:start+length]) + start += length + + + def get_additional_save_metadata(self) -> Dict[str, Any]: + # save the weight mapping + return { + "weight_mapping": self.weight_mapping, + "num_heads": self.num_heads, + "vision_hidden_size": self.vision_hidden_size, + "head_dim": self.head_dim, + "vision_tokens": self.vision_tokens, + "output_size": self.output_size, + } + diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py index 7508a914..c9797495 100644 --- a/toolkit/models/ilora.py +++ b/toolkit/models/ilora.py @@ -156,10 +156,10 @@ class InstantLoRAMidModule(torch.nn.Module): weight_chunk = weight_chunk.view(self.down_shape) # check if is conv or linear if len(weight_chunk.shape) == 4: - padding = 0 - if weight_chunk.shape[-1] == 3: - padding = 1 - x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding) + org_module = self.lora_module_ref().orig_module_ref() + stride = org_module.stride + padding = org_module.padding + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding, stride=stride) else: # run a simple linear layer with the down weight x_chunk = x_chunk @ weight_chunk.T diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py index 8daf6305..7d147585 100644 --- a/toolkit/models/te_adapter.py +++ b/toolkit/models/te_adapter.py @@ -6,7 +6,9 @@ import torch.nn.functional as F import weakref from typing import Union, TYPE_CHECKING + from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection +from diffusers.models.embeddings import PixArtAlphaTextProjection from toolkit import train_tools from toolkit.paths import REPOS_ROOT @@ -17,11 +19,71 @@ sys.path.append(REPOS_ROOT) from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0 + if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion from toolkit.custom_adapter import CustomAdapter +class TEAdapterCaptionProjection(nn.Module): + def __init__(self, caption_channels, adapter: 'TEAdapter'): + super().__init__() + in_features = caption_channels + self.adapter_ref: weakref.ref = weakref.ref(adapter) + sd = adapter.sd_ref() + self.parent_module_ref = weakref.ref(sd.transformer.caption_projection) + parent_module = self.parent_module_ref() + self.linear_1 = nn.Linear( + in_features=in_features, + out_features=parent_module.linear_1.out_features, + bias=True + ) + self.linear_2 = nn.Linear( + in_features=parent_module.linear_2.in_features, + out_features=parent_module.linear_2.out_features, + bias=True + ) + + # save the orig forward + parent_module.linear_1.orig_forward = parent_module.linear_1.forward + parent_module.linear_2.orig_forward = parent_module.linear_2.forward + + # replace original forward + parent_module.orig_forward = parent_module.forward + parent_module.forward = self.forward + + + @property + def is_active(self): + return self.adapter_ref().is_active + + @property + def unconditional_embeds(self): + return self.adapter_ref().adapter_ref().unconditional_embeds + + @property + def conditional_embeds(self): + return self.adapter_ref().adapter_ref().conditional_embeds + + def forward(self, caption): + if self.is_active and self.conditional_embeds is not None: + adapter_hidden_states = self.conditional_embeds.text_embeds + # check if we are doing unconditional + if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != caption.shape[0]: + # concat unconditional to match the hidden state batch size + if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1: + unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0) + else: + unconditional = self.unconditional_embeds.text_embeds + adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0) + hidden_states = self.linear_1(adapter_hidden_states) + hidden_states = self.parent_module_ref().act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + else: + return self.parent_module_ref().orig_forward(caption) + + class TEAdapterAttnProcessor(nn.Module): r""" Attention processor for Custom TE for PyTorch 2.0. @@ -177,6 +239,8 @@ class TEAdapter(torch.nn.Module): self.te_ref: weakref.ref = weakref.ref(te) self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer) self.adapter_modules = [] + self.caption_projection = None + self.embeds_store = [] is_pixart = sd.is_pixart if self.adapter_ref().config.text_encoder_arch == "t5": @@ -297,6 +361,11 @@ class TEAdapter(torch.nn.Module): transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks)) ]) + self.caption_projection = TEAdapterCaptionProjection( + caption_channels=self.token_size, + adapter=self, + ) + else: sd.unet.set_attn_processor(attn_procs) self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 8f996f74..01ec66ea 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -289,7 +289,13 @@ class ToolkitModuleMixin: scaled_lora_weight = lora_weight * scale scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight) - x = org_forwarded + scaled_lora_output + try: + x = org_forwarded + scaled_lora_output + except RuntimeError as e: + print(e) + print(org_forwarded.size()) + print(scaled_lora_output.size()) + raise e return x def enable_gradient_checkpointing(self: Module): diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 9b14b84a..ac552e3e 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -309,6 +309,9 @@ class StableDiffusion: main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS" if self.model_config.is_pixart_sigma: main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers" + + main_model_path = model_path + # load the TE in 8bit mode text_encoder = T5EncoderModel.from_pretrained( main_model_path,