mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Bug fixes and minor features
This commit is contained in:
@@ -6,8 +6,12 @@ import torch.nn.functional as F
|
||||
import weakref
|
||||
from typing import Union, TYPE_CHECKING
|
||||
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
|
||||
@@ -106,14 +110,14 @@ class TEAdapterAttnProcessor(nn.Module):
|
||||
|
||||
# only use one TE or the other. If our adapter is active only use ours
|
||||
if self.is_active and self.conditional_embeds is not None:
|
||||
adapter_hidden_states = self.conditional_embeds
|
||||
adapter_hidden_states = self.conditional_embeds.text_embeds
|
||||
# check if we are doing unconditional
|
||||
if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != encoder_hidden_states.shape[0]:
|
||||
# concat unconditional to match the hidden state batch size
|
||||
if self.unconditional_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
|
||||
unconditional = torch.cat([self.unconditional_embeds] * adapter_hidden_states.shape[0], dim=0)
|
||||
if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
|
||||
unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
|
||||
else:
|
||||
unconditional = self.unconditional_embeds
|
||||
unconditional = self.unconditional_embeds.text_embeds
|
||||
adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
|
||||
# for ip-adapter
|
||||
key = self.to_k_adapter(adapter_hidden_states)
|
||||
@@ -163,7 +167,7 @@ class TEAdapter(torch.nn.Module):
|
||||
self,
|
||||
adapter: 'CustomAdapter',
|
||||
sd: 'StableDiffusion',
|
||||
te: Union[T5EncoderModel, CLIPTextModel],
|
||||
te: Union[T5EncoderModel],
|
||||
tokenizer: CLIPTokenizer
|
||||
):
|
||||
super(TEAdapter, self).__init__()
|
||||
@@ -178,6 +182,12 @@ class TEAdapter(torch.nn.Module):
|
||||
else:
|
||||
self.token_size = self.te_ref().config.hidden_size
|
||||
|
||||
# add text projection if is sdxl
|
||||
self.text_projection = None
|
||||
if sd.is_xl:
|
||||
clip_with_projection: CLIPTextModelWithProjection = sd.text_encoder[0]
|
||||
self.text_projection = nn.Linear(te.config.hidden_size, clip_with_projection.config.projection_dim, bias=False)
|
||||
|
||||
# init adapter modules
|
||||
attn_procs = {}
|
||||
unet_sd = sd.unet.state_dict()
|
||||
@@ -258,16 +268,48 @@ class TEAdapter(torch.nn.Module):
|
||||
te: T5EncoderModel = self.te_ref()
|
||||
tokenizer: T5Tokenizer = self.tokenizer_ref()
|
||||
|
||||
input_ids = tokenizer(
|
||||
# input_ids = tokenizer(
|
||||
# text,
|
||||
# max_length=77,
|
||||
# padding="max_length",
|
||||
# truncation=True,
|
||||
# return_tensors="pt",
|
||||
# ).input_ids.to(te.device)
|
||||
# outputs = te(input_ids=input_ids)
|
||||
# outputs = outputs.last_hidden_state
|
||||
embeds, attention_mask = train_tools.encode_prompts_pixart(
|
||||
tokenizer,
|
||||
te,
|
||||
text,
|
||||
max_length=77,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids.to(te.device)
|
||||
outputs = te(input_ids=input_ids)
|
||||
outputs = outputs.last_hidden_state
|
||||
return outputs
|
||||
truncate=True,
|
||||
max_length=self.adapter_ref().config.num_tokens,
|
||||
)
|
||||
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
|
||||
if self.text_projection is not None:
|
||||
# pool the output of embeds ignoring 0 in the attention mask
|
||||
pooled_output = embeds * attn_mask_float.unsqueeze(-1)
|
||||
|
||||
# reduce along dim 1 while maintaining batch and dim 2
|
||||
pooled_output_sum = pooled_output.sum(dim=1)
|
||||
attn_mask_sum = attn_mask_float.sum(dim=1).unsqueeze(-1)
|
||||
|
||||
pooled_output = pooled_output_sum / attn_mask_sum
|
||||
|
||||
pooled_embeds = self.text_projection(pooled_output)
|
||||
|
||||
t5_embeds = PromptEmbeds(
|
||||
(embeds, pooled_embeds),
|
||||
attention_mask=attention_mask,
|
||||
).detach()
|
||||
|
||||
else:
|
||||
|
||||
t5_embeds = PromptEmbeds(
|
||||
embeds,
|
||||
attention_mask=attention_mask,
|
||||
).detach()
|
||||
|
||||
return t5_embeds
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user