Added Flux training. Still a WIP. Won't train correctly until rectified flow is working right.

Jaret Burkett
2024-08-02 15:00:30 -06:00
parent 03613c523f
commit 87ba867fdc
6 changed files with 292 additions and 15 deletions
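The commit message flags rectified flow as the blocker. For reference, a minimal sketch of the rectified-flow (flow-matching) objective that Flux-style training targets; model, latents, noise, timesteps, and cond are hypothetical placeholders, not names from this repo:

import torch
import torch.nn.functional as F

def rectified_flow_loss(model, latents, noise, timesteps, **cond):
    # linearly interpolate between clean latents (t=0) and pure noise (t=1)
    t = timesteps.view(-1, 1, 1, 1).to(latents.dtype)
    noisy_latents = (1.0 - t) * latents + t * noise
    # along this straight path the velocity is constant: noise - latents,
    # so the network regresses it directly with an MSE objective
    target = noise - latents
    pred = model(noisy_latents, timesteps, **cond)
    return F.mse_loss(pred.float(), target.float())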


@@ -3,7 +3,7 @@ import hashlib
 import json
 import os
 import time
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Union, List
 import sys
 from torch.cuda.amp import GradScaler
@@ -766,6 +766,73 @@ def encode_prompts_auraflow(
     return prompt_embeds, prompt_attention_mask
+
+
+def encode_prompts_flux(
+        tokenizer: List[Union['CLIPTokenizer', 'T5Tokenizer']],
+        text_encoder: List[Union['CLIPTextModel', 'T5EncoderModel']],
+        prompts: list[str],
+        truncate: bool = True,
+        max_length=None,
+        dropout_prob=0.0,
+):
+    if max_length is None:
+        max_length = 512
+    if dropout_prob > 0.0:
+        # randomly drop out prompts so the model also sees the unconditional case
+        prompts = [
+            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+        ]
+    device = text_encoder[0].device
+    dtype = text_encoder[0].dtype
+    batch_size = len(prompts)
+
+    # CLIP
+    text_inputs = tokenizer[0](
+        prompts,
+        padding="max_length",
+        max_length=tokenizer[0].model_max_length,
+        truncation=True,
+        return_overflowing_tokens=False,
+        return_length=False,
+        return_tensors="pt",
+    )
+    text_input_ids = text_inputs.input_ids
+    prompt_embeds = text_encoder[0](text_input_ids.to(device), output_hidden_states=False)
+
+    # Use pooled output of CLIPTextModel
+    pooled_prompt_embeds = prompt_embeds.pooler_output
+    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
+
+    # T5
+    text_inputs = tokenizer[1](
+        prompts,
+        padding="max_length",
+        max_length=max_length,
+        truncation=True,
+        return_length=False,
+        return_overflowing_tokens=False,
+        return_tensors="pt",
+    )
+    text_input_ids = text_inputs.input_ids
+    prompt_embeds = text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
+    dtype = text_encoder[1].dtype
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+    # prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
+    # prompt_embeds = prompt_embeds * prompt_attention_mask
+    # _, seq_len, _ = prompt_embeds.shape
+    # they don't do a prompt attention mask?
+    # prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
+
+    return prompt_embeds, pooled_prompt_embeds
+
+
 # for XL
 def get_add_time_ids(
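For orientation, a usage sketch of the new function, assuming the two tokenizer/encoder pairs are loaded from a diffusers-layout Flux checkpoint; the repo id, subfolder names, and output shapes below are assumptions based on the FLUX.1 release, not part of this commit:

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

repo = "black-forest-labs/FLUX.1-dev"  # illustrative checkpoint id
clip_tok = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
t5_tok = T5Tokenizer.from_pretrained(repo, subfolder="tokenizer_2")
clip_enc = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder").to("cuda")
t5_enc = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2").to("cuda")

prompt_embeds, pooled = encode_prompts_flux(
    tokenizer=[clip_tok, t5_tok],
    text_encoder=[clip_enc, t5_enc],
    prompts=["a photo of a cat"],
    dropout_prob=0.1,  # each prompt has a 10% chance of being replaced with ""
)
print(prompt_embeds.shape)  # (1, 512, 4096) from T5-XXL; pooled is (1, 768) from CLIP-L

Note the asymmetry the function encodes: the per-token sequence embeddings come from T5 (padded or truncated to max_length, default 512), while CLIP contributes only its pooled vector, which Flux consumes as a global conditioning signal.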