Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 19:21:39 +00:00)
Various features and fixes. Too much brain fog to do a proper description
This commit is contained in:
@@ -21,7 +21,7 @@ from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
|
||||
|
||||
@@ -202,6 +202,10 @@ class TEAdapterAttnProcessor(nn.Module):
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
# remove attn mask if doing clip
|
||||
if self.adapter_ref().adapter_ref().config.text_encoder_arch == "clip":
|
||||
attention_mask = None
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
@@ -246,7 +250,7 @@ class TEAdapter(torch.nn.Module):
|
||||
if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
|
||||
self.token_size = self.te_ref().config.d_model
|
||||
else:
|
||||
self.token_size = self.te_ref().config.target_hidden_size
|
||||
self.token_size = self.te_ref().config.hidden_size
|
||||
|
||||
# add text projection if is sdxl
|
||||
self.text_projection = None
|
||||
@@ -388,8 +392,17 @@ class TEAdapter(torch.nn.Module):
|
||||
# ).input_ids.to(te.device)
|
||||
# outputs = te(input_ids=input_ids)
|
||||
# outputs = outputs.last_hidden_state
|
||||
if self.adapter_ref().config.text_encoder_arch == "clip":
|
||||
embeds = train_tools.encode_prompts(
|
||||
tokenizer,
|
||||
te,
|
||||
text,
|
||||
truncate=True,
|
||||
max_length=self.adapter_ref().config.num_tokens,
|
||||
)
|
||||
attention_mask = torch.ones(embeds.shape[:2], device=embeds.device)
|
||||
|
||||
if self.adapter_ref().config.text_encoder_arch == "pile-t5":
|
||||
elif self.adapter_ref().config.text_encoder_arch == "pile-t5":
|
||||
# just use aura pile
|
||||
embeds, attention_mask = train_tools.encode_prompts_auraflow(
|
||||
tokenizer,
|
||||
@@ -407,7 +420,8 @@ class TEAdapter(torch.nn.Module):
|
||||
truncate=True,
|
||||
max_length=self.adapter_ref().config.num_tokens,
|
||||
)
|
||||
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
|
||||
if attention_mask is not None:
|
||||
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
|
||||
if self.text_projection is not None:
|
||||
# pool the output of embeds ignoring 0 in the attention mask
|
||||
pooled_output = embeds * attn_mask_float.unsqueeze(-1)
|
||||
@@ -420,19 +434,19 @@ class TEAdapter(torch.nn.Module):
|
||||
|
||||
pooled_embeds = self.text_projection(pooled_output)
|
||||
|
||||
t5_embeds = PromptEmbeds(
|
||||
prompt_embeds = PromptEmbeds(
|
||||
(embeds, pooled_embeds),
|
||||
attention_mask=attention_mask,
|
||||
).detach()
|
||||
|
||||
else:
|
||||
|
||||
t5_embeds = PromptEmbeds(
|
||||
prompt_embeds = PromptEmbeds(
|
||||
embeds,
|
||||
attention_mask=attention_mask,
|
||||
).detach()
|
||||
|
||||
return t5_embeds
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user