Added caching of image sizes so we don't recompute them every time.

This commit is contained in:
Jaret Burkett
2024-07-15 19:07:41 -06:00
parent e4558dff4b
commit 58dffd43a8
7 changed files with 90 additions and 34 deletions
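The hunks shown below cover the pile-t5 text-encoder branch; the image-size caching named in the commit message lives in the other changed files. As a rough, purely illustrative sketch of that idea (the cache dict and helper name are assumptions, not the repo's actual code):

import os
from PIL import Image

_image_size_cache = {}  # hypothetical module-level cache: (path, mtime) -> (width, height)

def get_image_size(path):
    # key on mtime so an edited file invalidates its cached size
    key = (path, os.path.getmtime(path))
    if key not in _image_size_cache:
        with Image.open(path) as img:
            _image_size_cache[key] = img.size  # PIL reads only the header here, no full decode
    return _image_size_cache[key]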

@@ -243,7 +243,7 @@ class TEAdapter(torch.nn.Module):
         self.embeds_store = []
         is_pixart = sd.is_pixart
-        if self.adapter_ref().config.text_encoder_arch == "t5":
+        if self.adapter_ref().config.text_encoder_arch == "t5" or self.adapter_ref().config.text_encoder_arch == "pile-t5":
             self.token_size = self.te_ref().config.d_model
         else:
             self.token_size = self.te_ref().config.target_hidden_size
@@ -388,13 +388,25 @@ class TEAdapter(torch.nn.Module):
             # ).input_ids.to(te.device)
             # outputs = te(input_ids=input_ids)
             # outputs = outputs.last_hidden_state
-            embeds, attention_mask = train_tools.encode_prompts_pixart(
-                tokenizer,
-                te,
-                text,
-                truncate=True,
-                max_length=self.adapter_ref().config.num_tokens,
-            )
+            if self.adapter_ref().config.text_encoder_arch == "pile-t5":
+                # just use aura pile
+                embeds, attention_mask = train_tools.encode_prompts_auraflow(
+                    tokenizer,
+                    te,
+                    text,
+                    truncate=True,
+                    max_length=self.adapter_ref().config.num_tokens,
+                )
+            else:
+                embeds, attention_mask = train_tools.encode_prompts_pixart(
+                    tokenizer,
+                    te,
+                    text,
+                    truncate=True,
+                    max_length=self.adapter_ref().config.num_tokens,
+                )
             attn_mask_float = attention_mask.to(embeds.device, dtype=embeds.dtype)
             if self.text_projection is not None:
                 # pool the output of embeds ignoring 0 in the attention mask
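The added branch above simply picks the prompt-encoding helper based on config.text_encoder_arch. An equivalent, illustrative way to express that selection (only encode_prompts_auraflow / encode_prompts_pixart and the surrounding names come from the diff; the dispatch dict is an assumption, not the committed code):

arch = self.adapter_ref().config.text_encoder_arch
encode_fn = {
    "pile-t5": train_tools.encode_prompts_auraflow,   # AuraFlow-style encoding for Pile-T5
}.get(arch, train_tools.encode_prompts_pixart)        # PixArt-style encoding otherwise
embeds, attention_mask = encode_fn(
    tokenizer,
    te,
    text,
    truncate=True,
    max_length=self.adapter_ref().config.num_tokens,
)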