Bug fixes and minor features

This commit is contained in:
Jaret Burkett
2024-04-25 06:14:31 -06:00
parent 5a70b7f38d
commit 5da3613e0b
12 changed files with 218 additions and 31 deletions

View File

@@ -6,8 +6,12 @@ import torch.nn.functional as F
import weakref
from typing import Union, TYPE_CHECKING
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
from toolkit import train_tools
from toolkit.paths import REPOS_ROOT
from toolkit.prompt_utils import PromptEmbeds
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
@@ -106,14 +110,14 @@ class TEAdapterAttnProcessor(nn.Module):
# only use one TE or the other. If our adapter is active only use ours
if self.is_active and self.conditional_embeds is not None:
adapter_hidden_states = self.conditional_embeds
adapter_hidden_states = self.conditional_embeds.text_embeds
# check if we are doing unconditional
if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != encoder_hidden_states.shape[0]:
# concat unconditional to match the hidden state batch size
if self.unconditional_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
unconditional = torch.cat([self.unconditional_embeds] * adapter_hidden_states.shape[0], dim=0)
if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
else:
unconditional = self.unconditional_embeds
unconditional = self.unconditional_embeds.text_embeds
adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
# for ip-adapter
key = self.to_k_adapter(adapter_hidden_states)
@@ -163,7 +167,7 @@ class TEAdapter(torch.nn.Module):
self,
adapter: 'CustomAdapter',
sd: 'StableDiffusion',
te: Union[T5EncoderModel, CLIPTextModel],
te: Union[T5EncoderModel],
tokenizer: CLIPTokenizer
):
super(TEAdapter, self).__init__()
@@ -178,6 +182,12 @@ class TEAdapter(torch.nn.Module):
else:
self.token_size = self.te_ref().config.hidden_size
# add text projection if is sdxl
self.text_projection = None
if sd.is_xl:
clip_with_projection: CLIPTextModelWithProjection = sd.text_encoder[0]
self.text_projection = nn.Linear(te.config.hidden_size, clip_with_projection.config.projection_dim, bias=False)
# init adapter modules
attn_procs = {}
unet_sd = sd.unet.state_dict()
@@ -258,16 +268,48 @@ class TEAdapter(torch.nn.Module):
te: T5EncoderModel = self.te_ref()
tokenizer: T5Tokenizer = self.tokenizer_ref()
input_ids = tokenizer(
# input_ids = tokenizer(
# text,
# max_length=77,
# padding="max_length",
# truncation=True,
# return_tensors="pt",
# ).input_ids.to(te.device)
# outputs = te(input_ids=input_ids)
# outputs = outputs.last_hidden_state
embeds, attention_mask = train_tools.encode_prompts_pixart(
tokenizer,
te,
text,
max_length=77,
padding="max_length",
truncation=True,
return_tensors="pt",
).input_ids.to(te.device)
outputs = te(input_ids=input_ids)
outputs = outputs.last_hidden_state
return outputs
truncate=True,
max_length=self.adapter_ref().config.num_tokens,
)
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
if self.text_projection is not None:
# pool the output of embeds ignoring 0 in the attention mask
pooled_output = embeds * attn_mask_float.unsqueeze(-1)
# reduce along dim 1 while maintaining batch and dim 2
pooled_output_sum = pooled_output.sum(dim=1)
attn_mask_sum = attn_mask_float.sum(dim=1).unsqueeze(-1)
pooled_output = pooled_output_sum / attn_mask_sum
pooled_embeds = self.text_projection(pooled_output)
t5_embeds = PromptEmbeds(
(embeds, pooled_embeds),
attention_mask=attention_mask,
).detach()
else:
t5_embeds = PromptEmbeds(
embeds,
attention_mask=attention_mask,
).detach()
return t5_embeds