Added Flux training. Still a WIP. Won't train right until rectified flow is working.

Jaret Burkett
2024-08-02 15:00:30 -06:00
parent 03613c523f
commit 87ba867fdc
6 changed files with 292 additions and 15 deletions
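The commit message points at rectified flow as the missing piece: Flux (like SD3) is trained with a flow-matching / rectified-flow objective rather than epsilon prediction. The sketch below is not part of this commit and uses a simplified, hypothetical transformer call signature; it only illustrates the loss the trainer would eventually need.

    import torch
    import torch.nn.functional as F

    def rectified_flow_loss(transformer, latents, text_embeds, pooled_embeds):
        # sample t in [0, 1] and build x_t on the straight path between data and noise:
        # x_t = (1 - t) * x_0 + t * noise
        noise = torch.randn_like(latents)
        t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
        t_ = t.view(-1, 1, 1, 1)
        noisy_latents = (1.0 - t_) * latents + t_ * noise

        # the model is trained to predict the velocity (noise - x_0)
        target = noise - latents
        pred = transformer(noisy_latents, t, text_embeds, pooled_embeds)  # hypothetical signature
        return F.mse_loss(pred.float(), target.float())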

View File

@@ -1234,7 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
             torch.backends.cuda.enable_mem_efficient_sdp(True)
         if self.train_config.gradient_checkpointing:
-            unet.enable_gradient_checkpointing()
+            if self.sd.is_flux:
+                unet.gradient_checkpointing = True
+            else:
+                unet.enable_gradient_checkpointing()
         if isinstance(text_encoder, list):
             for te in text_encoder:
                 if hasattr(te, 'enable_gradient_checkpointing'):
@@ -1325,6 +1328,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 is_v3=self.model_config.is_v3,
                 is_pixart=self.model_config.is_pixart,
                 is_auraflow=self.model_config.is_auraflow,
+                is_flux=self.model_config.is_flux,
                 is_ssd=self.model_config.is_ssd,
                 is_vega=self.model_config.is_vega,
                 dropout=self.network_config.dropout,

View File

@@ -367,6 +367,7 @@ class ModelConfig:
         self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False)
         self.is_auraflow: bool = kwargs.get('is_auraflow', False)
         self.is_v3: bool = kwargs.get('is_v3', False)
+        self.is_flux: bool = kwargs.get('is_flux', False)
         if self.is_pixart_sigma:
             self.is_pixart = True
         self.is_ssd: bool = kwargs.get('is_ssd', False)
@@ -404,6 +405,9 @@ class ModelConfig:
         self.vae_dtype = kwargs.get("vae_dtype", self.dtype)
         self.te_device = kwargs.get("te_device", None)
         self.te_dtype = kwargs.get("te_dtype", self.dtype)
+
+        # only for flux for now
+        self.quantize = kwargs.get("quantize", False)
         pass
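The two new flags are plain kwargs on ModelConfig, so a config that exercises this branch would just set them alongside the existing fields. A hypothetical fragment (the model path, dtype value, and field name `name_or_path` are assumptions, not taken from this commit):

    model_config_kwargs = {
        "name_or_path": "black-forest-labs/FLUX.1-dev",  # assumed; any Flux checkpoint in diffusers layout
        "is_flux": True,    # routes loading through the new Flux branch
        "quantize": True,   # qfloat8-quantize the transformer and T5 via optimum.quanto
        "dtype": "bf16",    # assumed field/value
    }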

View File

@@ -1361,6 +1361,8 @@ class LatentCachingMixin:
                 file_item.latent_space_version = 'sd3'
             elif self.sd.is_auraflow:
                 file_item.latent_space_version = 'sdxl'
+            elif self.sd.is_flux:
+                file_item.latent_space_version = 'flux'
             elif self.sd.model_config.is_pixart_sigma:
                 file_item.latent_space_version = 'sdxl'
             else:

View File

@@ -159,6 +159,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             is_v3=False,
             is_pixart: bool = False,
             is_auraflow: bool = False,
+            is_flux: bool = False,
             use_bias: bool = False,
             is_lorm: bool = False,
             ignore_if_contains=None,
@@ -216,6 +217,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.is_v3 = is_v3
         self.is_pixart = is_pixart
         self.is_auraflow = is_auraflow
+        self.is_flux = is_flux
         self.network_type = network_type
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
@@ -250,7 +252,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 target_replace_modules: List[torch.nn.Module],
         ) -> List[LoRAModule]:
             unet_prefix = self.LORA_PREFIX_UNET
-            if is_pixart or is_v3 or is_auraflow:
+            if is_pixart or is_v3 or is_auraflow or is_flux:
                 unet_prefix = f"lora_transformer"
             prefix = (
@@ -293,6 +295,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                     if self.transformer_only and self.is_pixart and is_unet:
                         if "transformer_blocks" not in lora_name:
                             skip = True
+                    if self.transformer_only and self.is_flux and is_unet:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
                     if (is_linear or is_conv2d) and not skip:
@@ -393,6 +398,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if is_auraflow:
             target_modules = ["AuraFlowTransformer2DModel"]
+        if is_flux:
+            target_modules = ["FluxTransformer2DModel"]
+
         if train_unet:
             self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
         else:
@@ -454,7 +462,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
         if self.full_train_in_out:
-            if self.is_pixart or self.is_auraflow:
+            if self.is_pixart or self.is_auraflow or self.is_flux:
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
             else:

View File

@@ -41,17 +41,21 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
     StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
     StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
-    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel
+    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
+    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler
 import diffusers
 from diffusers import \
     AutoencoderKL, \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
-from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel
+from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
+from optimum.quanto import freeze, qfloat8, quantize
 
 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -78,6 +82,7 @@ DO_NOT_TRAIN_WEIGHTS = [
 DeviceStatePreset = Literal['cache_latents', 'generate']
 
 class BlankNetwork:
     def __init__(self):
@@ -101,10 +106,6 @@ def flush():
 UNET_IN_CHANNELS = 4  # in_channels is fixed at 4 for Stable Diffusion; same for XL.
 # VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
 
-# if is type checking
-if typing.TYPE_CHECKING:
-    from diffusers.schedulers import KarrasDiffusionSchedulers
-    from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 
 class StableDiffusion:
@@ -158,6 +159,7 @@ class StableDiffusion:
         self.is_vega = model_config.is_vega
         self.is_pixart = model_config.is_pixart
         self.is_auraflow = model_config.is_auraflow
+        self.is_flux = model_config.is_flux
         self.use_text_encoder_1 = model_config.use_text_encoder_1
         self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -443,6 +445,71 @@ class StableDiffusion:
             text_encoder.eval()
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
             tokenizer = pipe.tokenizer
+        elif self.model_config.is_flux:
+            print("Loading Flux model")
+            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler")
+            print("Loading vae")
+            vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae", torch_dtype=dtype)
+            flush()
+
+            print("Loading transformer")
+            transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
+            transformer.to(self.device_torch, dtype=dtype)
+            flush()
+
+            if self.model_config.quantize:
+                print("Quantizing transformer")
+                quantize(transformer, weights=qfloat8)
+                freeze(transformer)
+                flush()
+
+            print("Loading t5")
+            text_encoder_2 = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder_2", torch_dtype=dtype)
+            tokenizer_2 = T5TokenizerFast.from_pretrained(model_path, subfolder="tokenizer_2", torch_dtype=dtype)
+            text_encoder_2.to(self.device_torch, dtype=dtype)
+            flush()
+
+            if self.model_config.quantize:
+                print("Quantizing T5")
+                quantize(text_encoder_2, weights=qfloat8)
+                freeze(text_encoder_2)
+                flush()
+
+            print("Loading clip")
+            text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
+            tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer", torch_dtype=dtype)
+            text_encoder.to(self.device_torch, dtype=dtype)
+
+            print("making pipe")
+            pipe = FluxPipeline(
+                scheduler=scheduler,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                text_encoder_2=None,
+                tokenizer_2=tokenizer_2,
+                vae=vae,
+                transformer=None,
+            )
+            pipe.text_encoder_2 = text_encoder_2
+            pipe.transformer = transformer
+
+            print("preparing")
+            text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
+            tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
+
+            pipe.transformer = pipe.transformer.to(self.device_torch)
+            flush()
+
+            text_encoder[0].to(self.device_torch)
+            text_encoder[0].requires_grad_(False)
+            text_encoder[0].eval()
+            text_encoder[1].to(self.device_torch)
+            text_encoder[1].requires_grad_(False)
+            text_encoder[1].eval()
+            pipe.transformer = pipe.transformer.to(self.device_torch)
+            flush()
         else:
             if self.custom_pipeline is not None:
                 pipln = self.custom_pipeline
@@ -515,7 +582,7 @@ class StableDiffusion:
         # add hacks to unet to help training
         # pipe.unet = prepare_unet_for_training(pipe.unet)
-        if self.is_pixart or self.is_v3 or self.is_auraflow:
+        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
             # pixart and sd3 dont use a unet
             self.unet = pipe.transformer
         else:
@@ -695,6 +762,18 @@ class StableDiffusion:
                 **extra_args
             ).to(self.device_torch)
             pipeline.watermark = None
+        elif self.is_flux:
+            pipeline = FluxPipeline(
+                vae=self.vae,
+                transformer=self.unet,
+                text_encoder=self.text_encoder[0],
+                text_encoder_2=self.text_encoder[1],
+                tokenizer=self.tokenizer[0],
+                tokenizer_2=self.tokenizer[1],
+                scheduler=noise_scheduler,
+                **extra_args
+            ).to(self.device_torch)
+            pipeline.watermark = None
         elif self.is_v3:
             pipeline = Pipe(
                 vae=self.vae,
@@ -954,6 +1033,19 @@ class StableDiffusion:
                     latents=gen_config.latents,
                     **extra
                 ).images[0]
+            elif self.is_flux:
+                img = pipeline(
+                    prompt_embeds=conditional_embeds.text_embeds,
+                    pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+                    # negative_prompt_embeds=unconditional_embeds.text_embeds,
+                    # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+                    height=gen_config.height,
+                    width=gen_config.width,
+                    num_inference_steps=gen_config.num_inference_steps,
+                    guidance_scale=gen_config.guidance_scale,
+                    latents=gen_config.latents,
+                    **extra
+                ).images[0]
             elif self.is_pixart:
                 # needs attention masks for some reason
                 img = pipeline(
@@ -1073,10 +1165,14 @@ class StableDiffusion:
         if width is None:
             width = pixel_width // VAE_SCALE_FACTOR
 
+        num_channels = self.unet.config['in_channels']
+        if self.is_flux:
+            # has 64 channels in for some reason
+            num_channels = 16
         noise = torch.randn(
             (
                 batch_size,
-                self.unet.config['in_channels'],
+                num_channels,
                 height,
                 width,
             ),
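The in_channels mismatch noted above (the transformer reports 64 while the VAE latent has 16 channels) comes from the packing step used in the denoising hunk below: FluxPipeline._pack_latents folds every 2x2 latent patch into the channel dimension before the transformer sees it. A minimal sketch of that reshape, written out here for clarity (it mirrors the diffusers helper):

    import torch

    def pack_latents(latents: torch.Tensor) -> torch.Tensor:
        # [B, 16, H, W] -> [B, (H/2)*(W/2), 64]: every 2x2 spatial patch becomes one token
        b, c, h, w = latents.shape
        latents = latents.view(b, c, h // 2, 2, w // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        return latents.reshape(b, (h // 2) * (w // 2), c * 4)

    # e.g. a 1024px image -> 128x128 latent -> 4096 tokens of 64 channels
    x = torch.randn(1, 16, 128, 128)
    print(pack_latents(x).shape)  # torch.Size([1, 4096, 64])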
@@ -1429,7 +1525,88 @@ class StableDiffusion:
             self.unet.to(self.device_torch)
         if self.unet.dtype != self.torch_dtype:
             self.unet = self.unet.to(dtype=self.torch_dtype)
-        if self.is_v3:
+        if self.is_flux:
+            with torch.no_grad():
+                VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)  # 16. Maybe don't subtract
+                # this is what diffusers does
+                text_ids = torch.zeros(latent_model_input.shape[0], text_embeddings.text_embeds.shape[1], 3).to(
+                    device=self.device_torch, dtype=self.text_encoder[0].dtype
+                )
+
+                # todo check these
+                # height = latent_model_input.shape[2] * VAE_SCALE_FACTOR
+                # width = latent_model_input.shape[3] * VAE_SCALE_FACTOR
+                height = latent_model_input.shape[2] * VAE_SCALE_FACTOR  # 128
+                width = latent_model_input.shape[3] * VAE_SCALE_FACTOR  # 128
+                width_latent = latent_model_input.shape[3]
+                height_latent = latent_model_input.shape[2]
+
+                latent_image_ids = self.pipeline._prepare_latent_image_ids(
+                    batch_size=latent_model_input.shape[0],
+                    height=height_latent,
+                    width=width_latent,
+                    device=self.device_torch,
+                    dtype=self.torch_dtype,
+                )
+
+                # handle guidance
+                guidance_scale = 1.0  # ?
+                if self.unet.config.guidance_embeds:
+                    guidance = torch.tensor([guidance_scale], device=self.device_torch)
+                    guidance = guidance.expand(latents.shape[0])
+                else:
+                    guidance = None
+
+                # not sure how to handle this
+                # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+                # image_seq_len = latents.shape[1]
+                # mu = calculate_shift(
+                #     image_seq_len,
+                #     self.scheduler.config.base_image_seq_len,
+                #     self.scheduler.config.max_image_seq_len,
+                #     self.scheduler.config.base_shift,
+                #     self.scheduler.config.max_shift,
+                # )
+                # timesteps, num_inference_steps = retrieve_timesteps(
+                #     self.scheduler,
+                #     num_inference_steps,
+                #     device,
+                #     timesteps,
+                #     sigmas,
+                #     mu=mu,
+                # )
+
+                latent_model_input = self.pipeline._pack_latents(
+                    latent_model_input,
+                    batch_size=latent_model_input.shape[0],
+                    num_channels_latents=latent_model_input.shape[1],  # 16
+                    height=height_latent,  # 128
+                    width=width_latent,  # 128
+                )
+
+            # transformer forward runs outside no_grad so gradients can flow to trained weights
+            noise_pred = self.unet(
+                hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),  # [1, 4096, 64]
+                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model
+                # (we should not keep it but I want to keep the inputs same for the model for testing)
+                # todo make sure this doesn't change
+                timestep=timestep / 1000,  # timestep is 1000 scale
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),  # [1, 512, 4096]
+                pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),  # [1, 768]
+                txt_ids=text_ids,  # [1, 512, 3]
+                img_ids=latent_image_ids,  # [1, 4096, 3]
+                guidance=guidance,
+                return_dict=False,
+                **kwargs,
+            )[0]
+
+            # unpack latents
+            noise_pred = self.pipeline._unpack_latents(
+                noise_pred,
+                height=height,  # 1024
+                width=width,  # 1024
+                vae_scale_factor=VAE_SCALE_FACTOR * 2,  # should be 16, not sure why
+            )
+        elif self.is_v3:
             noise_pred = self.unet(
                 hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
                 timestep=timestep,
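The commented-out calculate_shift block above refers to a diffusers helper that picks a resolution-dependent shift (mu) for the flow-match scheduler: it interpolates linearly between a base and a max shift as the number of image tokens grows. A sketch of that interpolation, with approximate default values (the exact numbers come from the scheduler config and are an assumption here):

    def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
        # linear interpolation of the shift between base_shift and max_shift
        # as image_seq_len grows from base_seq_len to max_seq_len
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        return image_seq_len * m + b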
@@ -1656,6 +1833,21 @@ class StableDiffusion:
                     embeds,
                     attention_mask=attention_mask,  # not used
                 )
+            elif self.is_flux:
+                prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
+                    self.tokenizer,  # list
+                    self.text_encoder,  # list
+                    prompt,
+                    truncate=not long_prompts,
+                    max_length=512,
+                    dropout_prob=dropout_prob
+                )
+                pe = PromptEmbeds(
+                    prompt_embeds
+                )
+                pe.pooled_embeds = pooled_prompt_embeds
+                return pe
+
             elif isinstance(self.text_encoder, T5EncoderModel):
                 embeds, attention_mask = train_tools.encode_prompts_pixart(
@@ -1989,7 +2181,7 @@ class StableDiffusion:
         named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
         unet_lr = unet_lr if unet_lr is not None else default_lr
         params = []
-        if self.is_pixart or self.is_auraflow:
+        if self.is_pixart or self.is_auraflow or self.is_flux:
             for param in named_params.values():
                 if param.requires_grad:
                     params.append(param)
@@ -2035,7 +2227,7 @@ class StableDiffusion:
     def save_device_state(self):
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
-        if self.is_pixart or self.is_v3 or self.is_auraflow:
+        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
             unet_has_grad = self.unet.proj_out.weight.requires_grad
         else:
             unet_has_grad = self.unet.conv_in.weight.requires_grad

View File

@@ -3,7 +3,7 @@ import hashlib
 import json
 import os
 import time
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Union, List
 import sys
 
 from torch.cuda.amp import GradScaler
@@ -766,6 +766,73 @@ def encode_prompts_auraflow(
     return prompt_embeds, prompt_attention_mask
 
+
+def encode_prompts_flux(
+        tokenizer: List[Union['CLIPTokenizer', 'T5Tokenizer']],
+        text_encoder: List[Union['CLIPTextModel', 'T5EncoderModel']],
+        prompts: list[str],
+        truncate: bool = True,
+        max_length=None,
+        dropout_prob=0.0,
+):
+    if max_length is None:
+        max_length = 512
+
+    if dropout_prob > 0.0:
+        # randomly drop out prompts
+        prompts = [
+            prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts
+        ]
+
+    device = text_encoder[0].device
+    dtype = text_encoder[0].dtype
+    batch_size = len(prompts)
+
+    # clip
+    text_inputs = tokenizer[0](
+        prompts,
+        padding="max_length",
+        max_length=tokenizer[0].model_max_length,
+        truncation=True,
+        return_overflowing_tokens=False,
+        return_length=False,
+        return_tensors="pt",
+    )
+    text_input_ids = text_inputs.input_ids
+    prompt_embeds = text_encoder[0](text_input_ids.to(device), output_hidden_states=False)
+
+    # Use pooled output of CLIPTextModel
+    pooled_prompt_embeds = prompt_embeds.pooler_output
+    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
+
+    # T5
+    text_inputs = tokenizer[1](
+        prompts,
+        padding="max_length",
+        max_length=max_length,
+        truncation=True,
+        return_length=False,
+        return_overflowing_tokens=False,
+        return_tensors="pt",
+    )
+    text_input_ids = text_inputs.input_ids
+    prompt_embeds = text_encoder[1](text_input_ids.to(device), output_hidden_states=False)[0]
+    dtype = text_encoder[1].dtype
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+    # prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape)
+    # prompt_embeds = prompt_embeds * prompt_attention_mask
+
+    # _, seq_len, _ = prompt_embeds.shape
+    # they dont do prompt attention mask?
+    # prompt_attention_mask = torch.ones((batch_size, seq_len), dtype=dtype, device=device)
+
+    return prompt_embeds, pooled_prompt_embeds
+
+
 # for XL
 def get_add_time_ids(
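Finally, a usage sketch of the new encode_prompts_flux helper, matching how the StableDiffusion class calls it above. Here `sd` stands in for a loaded StableDiffusion instance and the import path is assumed:

    from toolkit import train_tools  # assumed import path for this module

    # sd.tokenizer == [CLIPTokenizer, T5TokenizerFast], sd.text_encoder == [CLIPTextModel, T5EncoderModel]
    prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
        sd.tokenizer,
        sd.text_encoder,
        ["a photo of a cat"],
        truncate=True,
        max_length=512,
        dropout_prob=0.0,
    )
    # prompt_embeds: [1, 512, 4096] from T5; pooled_prompt_embeds: [1, 768] from CLIP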