mirror of https://github.com/ostris/ai-toolkit.git
Added flux training. Still a WIP. Won't train right without rectified flow working right.
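For reference, the "rectified flow" the message blames is the linear-interpolation flow-matching objective Flux trains with: noise and data are blended linearly, and the model regresses the constant velocity between them. A minimal sketch of the general formulation follows; the `rectified_flow_loss` name and the `model(x_t, t, cond)` signature are hypothetical, not this repo's trainer code:

import torch

def rectified_flow_loss(model, x0, t, cond):
    # x0: clean latents (batch, ...); t: timesteps in [0, 1], shape (batch,)
    noise = torch.randn_like(x0)
    t_ = t.view(-1, *([1] * (x0.dim() - 1)))  # broadcast t over latent dims
    x_t = (1.0 - t_) * x0 + t_ * noise        # linear interpolation path
    target = noise - x0                       # constant velocity d x_t / d t
    pred = model(x_t, t, cond)                # model predicts the velocity
    return torch.nn.functional.mse_loss(pred, target)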
@@ -3,7 +3,7 @@ import hashlib
 import json
 import os
 import time
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Union, List
 import sys
 
 from torch.cuda.amp import GradScaler
@@ -766,6 +766,73 @@ def encode_prompts_auraflow(
 
     return prompt_embeds, prompt_attention_mask
 
+def encode_prompts_flux(
+        tokenizer: List[Union['CLIPTokenizer', 'T5Tokenizer']],
+        text_encoder: List[Union['CLIPTextModel', 'T5EncoderModel']],
+        prompts: list[str],
+        truncate: bool = True,
+        max_length=None,
+        dropout_prob=0.0,
+):
+    if max_length is None:
+        max_length = 512
+
+    if dropout_prob > 0.0:
+        # randomly drop out prompts
+        prompts = [
+            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+        ]
+
+    device = text_encoder[0].device
+    dtype = text_encoder[0].dtype
+
+    batch_size = len(prompts)
+
+    # clip
+    text_inputs = tokenizer[0](
+        prompts,
+        padding="max_length",
+        max_length=tokenizer[0].model_max_length,
+        truncation=True,
+        return_overflowing_tokens=False,
+        return_length=False,
+        return_tensors="pt",
+    )
+
+    text_input_ids = text_inputs.input_ids
+
+    prompt_embeds = text_encoder[0](text_input_ids.to(device), output_hidden_states=False)
+
+    # Use pooled output of CLIPTextModel
+    pooled_prompt_embeds = prompt_embeds.pooler_output
+    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
+
+    # T5
+    text_inputs = tokenizer[1](
+        prompts,
+        padding="max_length",
+        max_length=max_length,
+        truncation=True,
+        return_length=False,
+        return_overflowing_tokens=False,
+        return_tensors="pt",
+    )
+    text_input_ids = text_inputs.input_ids
+
+    prompt_embeds = text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
+
+    dtype = text_encoder[1].dtype
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+    # prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
+    # prompt_embeds = prompt_embeds * prompt_attention_mask
+    # _, seq_len, _ = prompt_embeds.shape
+
+    # they dont do prompt attention mask?
+    # prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
+
+    return prompt_embeds, pooled_prompt_embeds
+
 
 # for XL
 def get_add_time_ids(
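A minimal usage sketch for the new helper, assuming a diffusers-style Flux checkpoint layout with CLIP at index 0 and T5 at index 1 (matching the indexing in the diff above); the repo id, subfolder names, and dtype/device choices are assumptions, not part of this commit:

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

# Hypothetical checkpoint id; any Flux checkpoint with the standard
# diffusers layout should expose these subfolders.
repo = "black-forest-labs/FLUX.1-dev"

tokenizer = [
    CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer"),
    T5Tokenizer.from_pretrained(repo, subfolder="tokenizer_2"),
]
text_encoder = [
    CLIPTextModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.float16).to("cuda"),
    T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2", torch_dtype=torch.float16).to("cuda"),
]

# encode_prompts_flux is the helper added in this commit
# (its module path is not shown in the diff, so no import here).
prompt_embeds, pooled_prompt_embeds = encode_prompts_flux(
    tokenizer,
    text_encoder,
    prompts=["a photo of a cat"],
    max_length=512,    # T5 sequence length; CLIP is capped at tokenizer[0].model_max_length
    dropout_prob=0.1,  # each prompt is independently replaced with "" ~10% of the time
)
# prompt_embeds:        per-token T5 embeddings, roughly (batch, 512, hidden)
# pooled_prompt_embeds: pooled CLIP vector, roughly (batch, 768) for CLIP-L

The pooled CLIP vector and the T5 token sequence map onto the two text-conditioning inputs the Flux transformer consumes (pooled projections and encoder hidden states, respectively), which is why the helper returns both instead of a single embedding.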