Various features and fixes. Too much brain fog to do a proper description

This commit is contained in:
Jaret Burkett
2024-07-18 07:34:14 -06:00
parent 58dffd43a8
commit 11e426fdf1
6 changed files with 119 additions and 25 deletions

View File

@@ -21,7 +21,7 @@ from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
from toolkit.custom_adapter import CustomAdapter
@@ -202,6 +202,10 @@ class TEAdapterAttnProcessor(nn.Module):
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
# remove attn mask if doing clip
if self.adapter_ref().adapter_ref().config.text_encoder_arch == "clip":
attention_mask = None
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
@@ -246,7 +250,7 @@ class TEAdapter(torch.nn.Module):
if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
self.token_size = self.te_ref().config.d_model
else:
self.token_size = self.te_ref().config.target_hidden_size
self.token_size = self.te_ref().config.hidden_size
# add text projection if is sdxl
self.text_projection = None
@@ -388,8 +392,17 @@ class TEAdapter(torch.nn.Module):
# ).input_ids.to(te.device)
# outputs = te(input_ids=input_ids)
# outputs = outputs.last_hidden_state
if self.adapter_ref().config.text_encoder_arch == "clip":
embeds = train_tools.encode_prompts(
tokenizer,
te,
text,
truncate=True,
max_length=self.adapter_ref().config.num_tokens,
)
attention_mask = torch.ones(embeds.shape[:2], device=embeds.device)
if self.adapter_ref().config.text_encoder_arch == "pile-t5":
elif self.adapter_ref().config.text_encoder_arch == "pile-t5":
# just use aura pile
embeds, attention_mask = train_tools.encode_prompts_auraflow(
tokenizer,
@@ -407,7 +420,8 @@ class TEAdapter(torch.nn.Module):
truncate=True,
max_length=self.adapter_ref().config.num_tokens,
)
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
if attention_mask is not None:
attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
if self.text_projection is not None:
# pool the output of embeds ignoring 0 in the attention mask
pooled_output = embeds * attn_mask_float.unsqueeze(-1)
@@ -420,19 +434,19 @@ class TEAdapter(torch.nn.Module):
pooled_embeds = self.text_projection(pooled_output)
t5_embeds = PromptEmbeds(
prompt_embeds = PromptEmbeds(
(embeds, pooled_embeds),
attention_mask=attention_mask,
).detach()
else:
t5_embeds = PromptEmbeds(
prompt_embeds = PromptEmbeds(
embeds,
attention_mask=attention_mask,
).detach()
return t5_embeds
return prompt_embeds