mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
WIP to add the caption_proj weight to pixart sigma TE adapter
This commit is contained in:
@@ -6,7 +6,9 @@ import torch.nn.functional as F
|
||||
import weakref
|
||||
from typing import Union, TYPE_CHECKING
|
||||
|
||||
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
|
||||
from diffusers.models.embeddings import PixArtAlphaTextProjection
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
@@ -17,11 +19,71 @@ sys.path.append(REPOS_ROOT)
|
||||
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
|
||||
|
||||
class TEAdapterCaptionProjection(nn.Module):
|
||||
def __init__(self, caption_channels, adapter: 'TEAdapter'):
|
||||
super().__init__()
|
||||
in_features = caption_channels
|
||||
self.adapter_ref: weakref.ref = weakref.ref(adapter)
|
||||
sd = adapter.sd_ref()
|
||||
self.parent_module_ref = weakref.ref(sd.transformer.caption_projection)
|
||||
parent_module = self.parent_module_ref()
|
||||
self.linear_1 = nn.Linear(
|
||||
in_features=in_features,
|
||||
out_features=parent_module.linear_1.out_features,
|
||||
bias=True
|
||||
)
|
||||
self.linear_2 = nn.Linear(
|
||||
in_features=parent_module.linear_2.in_features,
|
||||
out_features=parent_module.linear_2.out_features,
|
||||
bias=True
|
||||
)
|
||||
|
||||
# save the orig forward
|
||||
parent_module.linear_1.orig_forward = parent_module.linear_1.forward
|
||||
parent_module.linear_2.orig_forward = parent_module.linear_2.forward
|
||||
|
||||
# replace original forward
|
||||
parent_module.orig_forward = parent_module.forward
|
||||
parent_module.forward = self.forward
|
||||
|
||||
|
||||
@property
|
||||
def is_active(self):
|
||||
return self.adapter_ref().is_active
|
||||
|
||||
@property
|
||||
def unconditional_embeds(self):
|
||||
return self.adapter_ref().adapter_ref().unconditional_embeds
|
||||
|
||||
@property
|
||||
def conditional_embeds(self):
|
||||
return self.adapter_ref().adapter_ref().conditional_embeds
|
||||
|
||||
def forward(self, caption):
|
||||
if self.is_active and self.conditional_embeds is not None:
|
||||
adapter_hidden_states = self.conditional_embeds.text_embeds
|
||||
# check if we are doing unconditional
|
||||
if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != caption.shape[0]:
|
||||
# concat unconditional to match the hidden state batch size
|
||||
if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
|
||||
unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
|
||||
else:
|
||||
unconditional = self.unconditional_embeds.text_embeds
|
||||
adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
|
||||
hidden_states = self.linear_1(adapter_hidden_states)
|
||||
hidden_states = self.parent_module_ref().act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
else:
|
||||
return self.parent_module_ref().orig_forward(caption)
|
||||
|
||||
|
||||
class TEAdapterAttnProcessor(nn.Module):
|
||||
r"""
|
||||
Attention processor for Custom TE for PyTorch 2.0.
|
||||
@@ -177,6 +239,8 @@ class TEAdapter(torch.nn.Module):
|
||||
self.te_ref: weakref.ref = weakref.ref(te)
|
||||
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
||||
self.adapter_modules = []
|
||||
self.caption_projection = None
|
||||
self.embeds_store = []
|
||||
is_pixart = sd.is_pixart
|
||||
|
||||
if self.adapter_ref().config.text_encoder_arch == "t5":
|
||||
@@ -297,6 +361,11 @@ class TEAdapter(torch.nn.Module):
|
||||
transformer.transformer_blocks[i].attn2.processor for i in
|
||||
range(len(transformer.transformer_blocks))
|
||||
])
|
||||
self.caption_projection = TEAdapterCaptionProjection(
|
||||
caption_channels=self.token_size,
|
||||
adapter=self,
|
||||
)
|
||||
|
||||
else:
|
||||
sd.unet.set_attn_processor(attn_procs)
|
||||
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
|
||||
|
||||
Reference in New Issue
Block a user