Fixes and longer prompts
@@ -447,29 +447,78 @@ if TYPE_CHECKING:
 def text_tokenize(
-        tokenizer: 'CLIPTokenizer',  # normally one, two for XL!
+        tokenizer: 'CLIPTokenizer',
         prompts: list[str],
+        truncate: bool = True,
+        max_length: int = None,
+        max_length_multiplier: int = 4,
 ):
-    return tokenizer(
+    # when not truncating, allow for up to 4x the max length for long prompts
+    if max_length is None:
+        if truncate:
+            max_length = tokenizer.model_max_length
+        else:
+            max_length = tokenizer.model_max_length * max_length_multiplier
+
+    input_ids = tokenizer(
         prompts,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
+        padding='max_length',
+        max_length=max_length,
         truncation=True,
         return_tensors="pt",
     ).input_ids
+
+    if truncate or max_length == tokenizer.model_max_length:
+        return input_ids
+    else:
+        # trim the extra padding: split into model_max_length-sized chunks
+        num_chunks = input_ids.shape[1] // tokenizer.model_max_length
+        chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1)
+
+        # keep only the chunks that contain something besides padding
+        non_redundant_chunks = []
+
+        for chunk in chunks:
+            # a chunk is pure padding when every element equals its first element
+            if not chunk.eq(chunk[0, 0]).all():
+                non_redundant_chunks.append(chunk)
+
+        input_ids = torch.cat(non_redundant_chunks, dim=1)
+        return input_ids
 
 
 # https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
 def text_encode_xl(
         text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
         tokens: torch.FloatTensor,
         num_images_per_prompt: int = 1,
+        max_length: int = 77,  # not sure what default to put here; always pass one?
+        truncate: bool = True,
 ):
-    prompt_embeds = text_encoder(
-        tokens.to(text_encoder.device), output_hidden_states=True
-    )
-    pooled_prompt_embeds = prompt_embeds[0]
-    prompt_embeds = prompt_embeds.hidden_states[-2]  # always penultimate layer
+    if truncate:
+        # normal short prompt, 77 tokens max
+        prompt_embeds = text_encoder(
+            tokens.to(text_encoder.device), output_hidden_states=True
+        )
+        pooled_prompt_embeds = prompt_embeds[0]
+        prompt_embeds = prompt_embeds.hidden_states[-2]  # always penultimate layer
+    else:
+        # handle long prompts by encoding max_length-sized sections separately
+        prompt_embeds_list = []
+        tokens = tokens.to(text_encoder.device)
+        pooled_prompt_embeds = None
+        for i in range(0, tokens.shape[-1], max_length):
+            # todo: run the sections through the encoder as a single batch
+            section_tokens = tokens[:, i: i + max_length]
+            embeds = text_encoder(section_tokens, output_hidden_states=True)
+            pooled_prompt_embed = embeds[0]
+            if pooled_prompt_embeds is None:
+                # we only want the pooled output of the first section (I think??)
+                pooled_prompt_embeds = pooled_prompt_embed
+            prompt_embed = embeds.hidden_states[-2]  # always penultimate layer
+            prompt_embeds_list.append(prompt_embed)
+
+        prompt_embeds = torch.cat(prompt_embeds_list, dim=1)
 
     bs_embed, seq_len, _ = prompt_embeds.shape
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
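
To see what the new truncate/max_length path does, here is a minimal usage sketch for text_tokenize; it assumes the stock transformers CLIPTokenizer, and the model id and prompt are illustrative rather than part of this commit:

import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # model_max_length == 77

prompt = "a highly detailed painting of " + ", ".join(["an ornate clockwork city"] * 30)

ids_short = text_tokenize(tokenizer, [prompt])                 # old behavior: hard truncation at 77
ids_long = text_tokenize(tokenizer, [prompt], truncate=False)  # padded to 4x77, all-padding chunks dropped

print(ids_short.shape)  # torch.Size([1, 77])
print(ids_long.shape)   # torch.Size([1, k * 77]) for some k between 1 and 4

Because the padded length is always a whole multiple of model_max_length, torch.chunk splits it into equal 77-token chunks before the all-padding ones are discarded.
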
@@ -485,7 +534,9 @@ def encode_prompts_xl(
         prompts2: Union[list[str], None],
         num_images_per_prompt: int = 1,
         use_text_encoder_1: bool = True,  # sdxl
-        use_text_encoder_2: bool = True  # sdxl
+        use_text_encoder_2: bool = True,  # sdxl
+        truncate: bool = True,
+        max_length=None,
 ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
     # text_encoder and text_encoder_2's penultimate layer's output
     text_embeds_list = []
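
The two new keyword arguments are simply threaded through to the helpers above. A hypothetical call could look like the following; the leading tokenizer and text-encoder arguments are not visible in this hunk, so their names here are assumptions:

text_embeds, pooled_embeds = encode_prompts_xl(
    tokenizers,      # assumed: [tokenizer_1, tokenizer_2]
    text_encoders,   # assumed: [text_encoder_1, text_encoder_2]
    prompts,
    prompts2=None,
    num_images_per_prompt=1,
    truncate=False,   # opt in to the long-prompt path
    max_length=None,  # let text_tokenize expand up to 4x model_max_length
)
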
@@ -502,9 +553,14 @@ def encode_prompts_xl(
         if idx == 1 and not use_text_encoder_2:
             prompt_list_to_use = ["" for _ in prompts]
 
-        text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use)
+        text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
+        # pin the max length so the second tokenizer produces the same sequence length
+        if idx == 0:
+            max_length = text_tokens_input_ids.shape[-1]
 
         text_embeds, pooled_text_embeds = text_encode_xl(
-            text_encoder, text_tokens_input_ids, num_images_per_prompt
+            text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length,
+            truncate=truncate
         )
 
         text_embeds_list.append(text_embeds)
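
The idx == 0 branch pins max_length because the two encoders' embeddings are concatenated along the feature dimension below (torch.concat(..., dim=-1)), which requires matching sequence lengths, so the length produced by the first tokenizer is forced onto the second. A sketch of the effect, with illustrative names and shapes:

ids_1 = text_tokenize(tokenizer_1, prompts, truncate=False)  # e.g. shape [1, 231]
ids_2 = text_tokenize(tokenizer_2, prompts, truncate=False,
                      max_length=ids_1.shape[-1])            # padded toward the same [1, 231]

One caveat: if the second prompt list tokenizes to fewer non-padding chunks (for example, empty strings when use_text_encoder_2 is False), text_tokenize's chunk-dropping could still return a shorter tensor than the first encoder's, and the concatenation would fail; the diff does not appear to guard against that case.
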
@@ -517,18 +573,36 @@ def encode_prompts_xl(
     return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
 
 
-def text_encode(text_encoder: 'CLIPTextModel', tokens):
-    return text_encoder(tokens.to(text_encoder.device))[0]
+# ref for long prompts: https://github.com/huggingface/diffusers/issues/2136
+def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
+    if max_length is None and not truncate:
+        raise ValueError("max_length must be set if truncate is False")
+
+    tokens = tokens.to(text_encoder.device)
+
+    if truncate:
+        return text_encoder(tokens)[0]
+    else:
+        # handle long prompts: encode max_length-sized sections, then stitch them back together
+        prompt_embeds_list = []
+        for i in range(0, tokens.shape[-1], max_length):
+            prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0]
+            prompt_embeds_list.append(prompt_embeds)
+
+        return torch.cat(prompt_embeds_list, dim=1)
 
 
 def encode_prompts(
         tokenizer: 'CLIPTokenizer',
-        text_encoder: 'CLIPTokenizer',
+        text_encoder: 'CLIPTextModel',
         prompts: list[str],
+        truncate: bool = True,
+        max_length=None,
 ):
-    text_tokens = text_tokenize(tokenizer, prompts)
-    text_embeddings = text_encode(text_encoder, text_tokens)
+    # leave max_length=None so text_tokenize can expand it when truncate=False
+    text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
+    # text_encode sections by model_max_length, mirroring the XL path above
+    text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate,
+                                  max_length=tokenizer.model_max_length)
 
     return text_embeddings
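
Putting the non-XL pieces together, a minimal end-to-end sketch, assuming an SD 1.x-style CLIP pair; the model id is illustrative and the helpers are assumed importable from this module:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompt = "a long, rambling, extremely specific prompt about a clockwork city, " * 20

with torch.no_grad():
    embeddings = encode_prompts(tokenizer, text_encoder, [prompt], truncate=False)

print(embeddings.shape)  # torch.Size([1, k * 77, 768]) -- one encoded 77-token section per kept chunk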