WIP: add the caption_proj weight to the PixArt Sigma TE adapter

This commit is contained in:
Jaret Burkett
2024-07-06 13:00:21 -06:00
parent acb06d6ff3
commit cab8a1c7b8
8 changed files with 500 additions and 23 deletions
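Note: the new TEAdapterCaptionProjection below shadows PixArt's caption_projection module. It builds its own linear_1/linear_2 sized from the parent module, stashes the parent's forward as orig_forward, and swaps the parent's forward so that, while the adapter is active, the adapter's text embeds are projected instead of the regular caption path. A minimal, self-contained sketch of that forward-swap pattern follows; ParentProjection and AdapterProjection are illustrative stand-ins, not toolkit classes, and all shapes are assumptions.

import torch
import torch.nn as nn

class ParentProjection(nn.Module):
    # Stand-in for PixArt's caption projection: linear_1 -> act_1 -> linear_2.
    def __init__(self, in_features=16, hidden=32, out_features=8):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, hidden)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(hidden, out_features)

    def forward(self, caption):
        return self.linear_2(self.act_1(self.linear_1(caption)))

class AdapterProjection(nn.Module):
    # Owns its own linear layers, sized from the parent, and hijacks the parent's forward.
    def __init__(self, parent: ParentProjection, adapter_channels: int):
        super().__init__()
        self.parent = parent
        self.linear_1 = nn.Linear(adapter_channels, parent.linear_1.out_features)
        self.linear_2 = nn.Linear(parent.linear_2.in_features, parent.linear_2.out_features)
        # keep the original forward around so we can fall back when inactive
        parent.orig_forward = parent.forward
        parent.forward = self.forward
        self.adapter_embeds = None  # set externally while the adapter is active

    def forward(self, caption):
        if self.adapter_embeds is not None:
            # active: project the adapter embeds through the new layers,
            # reusing the parent's activation
            return self.linear_2(self.parent.act_1(self.linear_1(self.adapter_embeds)))
        # inactive: behave exactly like the unpatched module
        return self.parent.orig_forward(caption)

parent = ParentProjection()
proj = AdapterProjection(parent, adapter_channels=24)
print(parent(torch.randn(2, 16)).shape)   # inactive path -> torch.Size([2, 8])
proj.adapter_embeds = torch.randn(2, 77, 24)
print(parent(torch.randn(2, 16)).shape)   # active path -> torch.Size([2, 77, 8])

The actual module in this commit differs in that it holds weak references (adapter_ref, parent_module_ref) instead of direct ones and gates on is_active/conditional_embeds rather than a plain attribute.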


@@ -6,7 +6,9 @@ import torch.nn.functional as F
import weakref
from typing import Union, TYPE_CHECKING
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
from diffusers.models.embeddings import PixArtAlphaTextProjection
from toolkit import train_tools
from toolkit.paths import REPOS_ROOT
@@ -17,11 +19,71 @@ sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.custom_adapter import CustomAdapter
class TEAdapterCaptionProjection(nn.Module):
    # Shadows PixArt's caption_projection: projects the adapter's text embeds through
    # its own linear layers when the adapter is active, otherwise falls back to the
    # original projection.
    def __init__(self, caption_channels, adapter: 'TEAdapter'):
        super().__init__()
        in_features = caption_channels
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        sd = adapter.sd_ref()
        self.parent_module_ref = weakref.ref(sd.transformer.caption_projection)
        parent_module = self.parent_module_ref()
        self.linear_1 = nn.Linear(
            in_features=in_features,
            out_features=parent_module.linear_1.out_features,
            bias=True
        )
        self.linear_2 = nn.Linear(
            in_features=parent_module.linear_2.in_features,
            out_features=parent_module.linear_2.out_features,
            bias=True
        )
        # save the orig forwards
        parent_module.linear_1.orig_forward = parent_module.linear_1.forward
        parent_module.linear_2.orig_forward = parent_module.linear_2.forward
        # replace the original forward with ours
        parent_module.orig_forward = parent_module.forward
        parent_module.forward = self.forward

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    @property
    def unconditional_embeds(self):
        return self.adapter_ref().adapter_ref().unconditional_embeds

    @property
    def conditional_embeds(self):
        return self.adapter_ref().adapter_ref().conditional_embeds

    def forward(self, caption):
        if self.is_active and self.conditional_embeds is not None:
            adapter_hidden_states = self.conditional_embeds.text_embeds
            # check if we are doing unconditional
            if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != caption.shape[0]:
                # concat unconditional to match the hidden state batch size
                if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
                    unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
                else:
                    unconditional = self.unconditional_embeds.text_embeds
                adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
            hidden_states = self.linear_1(adapter_hidden_states)
            hidden_states = self.parent_module_ref().act_1(hidden_states)
            hidden_states = self.linear_2(hidden_states)
            return hidden_states
        else:
            return self.parent_module_ref().orig_forward(caption)
class TEAdapterAttnProcessor(nn.Module):
    r"""
    Attention processor for the custom TE adapter for PyTorch 2.0.
@@ -177,6 +239,8 @@ class TEAdapter(torch.nn.Module):
        self.te_ref: weakref.ref = weakref.ref(te)
        self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
        self.adapter_modules = []
        self.caption_projection = None
        self.embeds_store = []
        is_pixart = sd.is_pixart
        if self.adapter_ref().config.text_encoder_arch == "t5":
@@ -297,6 +361,11 @@ class TEAdapter(torch.nn.Module):
                transformer.transformer_blocks[i].attn2.processor for i in
                range(len(transformer.transformer_blocks))
            ])
            self.caption_projection = TEAdapterCaptionProjection(
                caption_channels=self.token_size,
                adapter=self,
            )
        else:
            sd.unet.set_attn_processor(attn_procs)
            self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
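For reference, the batch handling in the new forward tiles a single unconditional embed to match the conditional batch before concatenating the two (unconditional first). A small standalone sketch of that logic follows; the shapes (77 tokens, 1024 channels) are illustrative assumptions, not taken from this diff.

import torch

cond = torch.randn(4, 77, 1024)    # conditional adapter embeds, batch of 4
uncond = torch.randn(1, 77, 1024)  # a single shared unconditional embed

# mirror the shape check in TEAdapterCaptionProjection.forward
if uncond.shape[0] == 1 and cond.shape[0] != 1:
    uncond = torch.cat([uncond] * cond.shape[0], dim=0)

hidden = torch.cat([uncond, cond], dim=0)  # CFG-style batch: uncond then cond
print(hidden.shape)  # torch.Size([8, 77, 1024])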