diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 3c85de62..d5a2d95e 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -883,6 +883,26 @@ class SDTrainer(BaseSDTrainProcess): def end_of_training_loop(self): pass + def predict_noise( + self, + noisy_latents: torch.Tensor, + timesteps: Union[int, torch.Tensor] = 1, + conditional_embeds: Union[PromptEmbeds, None] = None, + unconditional_embeds: Union[PromptEmbeds, None] = None, + **kwargs, + ): + dtype = get_torch_dtype(self.train_config.dtype) + return self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeddings=unconditional_embeds, + timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + detach_unconditional=False, + rescale_cfg=self.train_config.cfg_rescale, + **kwargs + ) + def hook_train_loop(self, batch: 'DataLoaderBatchDTO'): self.timer.start('preprocess_batch') batch = self.preprocess_batch(batch) @@ -1453,14 +1473,11 @@ class SDTrainer(BaseSDTrainProcess): with self.timer('predict_unet'): if unconditional_embeds is not None: unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() - noise_pred = self.sd.predict_noise( - latents=noisy_latents.to(self.device_torch, dtype=dtype), - conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), - unconditional_embeddings=unconditional_embeds, - timestep=timesteps, - guidance_scale=self.train_config.cfg_scale, - detach_unconditional=False, - rescale_cfg=self.train_config.cfg_rescale, + noise_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, **pred_kwargs ) self.after_unet_predict() diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 78280288..b12f7e7c 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1287,6 +1287,7 @@ class BaseSDTrainProcess(BaseTrainProcess): is_v2=self.model_config.is_v2, is_v3=self.model_config.is_v3, is_pixart=self.model_config.is_pixart, + is_auraflow=self.model_config.is_auraflow, is_ssd=self.model_config.is_ssd, is_vega=self.model_config.is_vega, dropout=self.network_config.dropout, diff --git a/toolkit/buckets.py b/toolkit/buckets.py index be69085b..aec99040 100644 --- a/toolkit/buckets.py +++ b/toolkit/buckets.py @@ -53,6 +53,50 @@ resolutions_1024: List[BucketResolution] = [ {"width": 512, "height": 2048}, ] +# Even numbers so they can be patched easier +resolutions_dit_1024: List[BucketResolution] = [ + # Base resolution + {"width": 1024, "height": 1024}, + # widescreen + {"width": 2048, "height": 512}, + {"width": 1792, "height": 576}, + {"width": 1728, "height": 576}, + {"width": 1664, "height": 576}, + {"width": 1600, "height": 640}, + {"width": 1536, "height": 640}, + {"width": 1472, "height": 704}, + {"width": 1408, "height": 704}, + {"width": 1344, "height": 704}, + {"width": 1344, "height": 768}, + {"width": 1280, "height": 768}, + {"width": 1216, "height": 832}, + {"width": 1152, "height": 832}, + {"width": 1152, "height": 896}, + {"width": 1088, "height": 896}, + {"width": 1088, "height": 960}, + {"width": 1024, "height": 960}, + # portrait + {"width": 960, "height": 1024}, + {"width": 
960, "height": 1088}, + {"width": 896, "height": 1088}, + {"width": 896, "height": 1152}, # 2:3 + {"width": 832, "height": 1152}, + {"width": 832, "height": 1216}, + {"width": 768, "height": 1280}, + {"width": 768, "height": 1344}, + {"width": 704, "height": 1408}, + {"width": 704, "height": 1472}, + {"width": 640, "height": 1536}, + {"width": 640, "height": 1600}, + {"width": 576, "height": 1664}, + {"width": 576, "height": 1728}, + {"width": 576, "height": 1792}, + {"width": 512, "height": 1856}, + {"width": 512, "height": 1920}, + {"width": 512, "height": 1984}, + {"width": 512, "height": 2048}, +] + def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]: # determine scaler form 1024 to resolution diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 5edc3fb7..f05b8b3e 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -350,6 +350,7 @@ class ModelConfig: self.is_xl: bool = kwargs.get('is_xl', False) self.is_pixart: bool = kwargs.get('is_pixart', False) self.is_pixart_sigma: bool = kwargs.get('is_pixart_sigma', False) + self.is_auraflow: bool = kwargs.get('is_auraflow', False) self.is_v3: bool = kwargs.get('is_v3', False) if self.is_pixart_sigma: self.is_pixart = True @@ -381,7 +382,7 @@ class ModelConfig: self.is_xl = True # for text encoder quant. Only works with pixart currently - self.text_encoder_bits = kwargs.get('text_encoder_bits', 8) # 16, 8, 4 + self.text_encoder_bits = kwargs.get('text_encoder_bits', 16) # 16, 8, 4 self.unet_path = kwargs.get("unet_path", None) self.unet_sample_size = kwargs.get("unet_sample_size", None) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index ac96a3a2..b52dfc53 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -1355,6 +1355,10 @@ class LatentCachingMixin: file_item.latent_space_version = 'sdxl' elif self.sd.is_v3: file_item.latent_space_version = 'sd3' + elif self.sd.is_auraflow: + file_item.latent_space_version = 'sdxl' + elif self.sd.model_config.is_pixart_sigma: + file_item.latent_space_version = 'sdxl' else: file_item.latent_space_version = 'sd1' file_item.is_caching_to_disk = to_disk diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index a71a71b7..0e46e4ff 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -7,7 +7,7 @@ import re import sys from typing import List, Optional, Dict, Type, Union import torch -from diffusers import UNet2DConditionModel, PixArtTransformer2DModel +from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel from transformers import CLIPTextModel from .config_modules import NetworkConfig @@ -158,6 +158,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): is_v2=False, is_v3=False, is_pixart: bool = False, + is_auraflow: bool = False, use_bias: bool = False, is_lorm: bool = False, ignore_if_contains = None, @@ -212,6 +213,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.is_v2 = is_v2 self.is_v3 = is_v3 self.is_pixart = is_pixart + self.is_auraflow = is_auraflow self.network_type = network_type if self.network_type.lower() == "dora": self.module_class = DoRAModule @@ -246,7 +248,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): target_replace_modules: List[torch.nn.Module], ) -> List[LoRAModule]: unet_prefix = self.LORA_PREFIX_UNET - if is_pixart or is_v3: + if is_pixart or is_v3 or is_auraflow: unet_prefix = f"lora_transformer" prefix = ( @@ -371,6 +373,9 @@ class 
LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): if is_pixart: target_modules = ["PixArtTransformer2DModel"] + if is_auraflow: + target_modules = ["AuraFlowTransformer2DModel"] + if train_unet: self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) else: @@ -408,6 +413,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): transformer.pos_embed = self.transformer_pos_embed transformer.proj_out = self.transformer_proj_out + elif self.is_auraflow: + transformer: AuraFlowTransformer2DModel = unet + self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed) + self.transformer_proj_out = copy.deepcopy(transformer.proj_out) + + transformer.pos_embed = self.transformer_pos_embed + transformer.proj_out = self.transformer_proj_out + else: unet: UNet2DConditionModel = unet unet_conv_in: torch.nn.Conv2d = unet.conv_in @@ -424,7 +437,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr) if self.full_train_in_out: - if self.is_pixart: + if self.is_pixart or self.is_auraflow: all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())}) all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())}) else: diff --git a/toolkit/models/auraflow.py b/toolkit/models/auraflow.py new file mode 100644 index 00000000..e2539bda --- /dev/null +++ b/toolkit/models/auraflow.py @@ -0,0 +1,127 @@ +import math +from functools import partial + +from torch import nn +import torch + + +class AuraFlowPatchEmbed(nn.Module): + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + pos_embed_max_size=None, + ): + super().__init__() + + self.num_patches = (height // patch_size) * (width // patch_size) + self.pos_embed_max_size = pos_embed_max_size + + self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim) + self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1) + + self.patch_size = patch_size + self.height, self.width = height // patch_size, width // patch_size + self.base_size = height // patch_size + + def forward(self, latent): + batch_size, num_channels, height, width = latent.size() + latent = latent.view( + batch_size, + num_channels, + height // self.patch_size, + self.patch_size, + width // self.patch_size, + self.patch_size, + ) + latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) + latent = self.proj(latent) + try: + return latent + self.pos_embed + except RuntimeError: + raise RuntimeError( + f"Positional embeddings are too small for the number of patches. " + f"Please increase `pos_embed_max_size` to at least {self.num_patches}." 
+            )
+
+
+# comfy
+# def apply_pos_embeds(self, x, h, w):
+#     h = (h + 1) // self.patch_size
+#     w = (w + 1) // self.patch_size
+#     max_dim = max(h, w)
+#
+#     cur_dim = self.h_max
+#     pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
+#
+#     if max_dim > cur_dim:
+#         pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1,
+#                                                                                                                 -1)
+#         cur_dim = max_dim
+#
+#     from_h = (cur_dim - h) // 2
+#     from_w = (cur_dim - w) // 2
+#     pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w]
+#     return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
+
+    # def patchify(self, x):
+    #     B, C, H, W = x.size()
+    #     pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
+    #     pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
+    #
+    #     x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
+    #     x = x.view(
+    #         B,
+    #         C,
+    #         (H + 1) // self.patch_size,
+    #         self.patch_size,
+    #         (W + 1) // self.patch_size,
+    #         self.patch_size,
+    #     )
+    #     x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
+    #     return x
+
+def patch_auraflow_pos_embed(pos_embed):
+    # we need to hijack the forward and replace it with a custom one.
+    # `self` here is the pos_embed module itself, bound via functools.partial below.
+    def new_forward(self, latent):
+        batch_size, num_channels, height, width = latent.size()
+
+        # pad the latent so its patch count matches the pretrained pos_embed
+        # 16 == patch_size**2 * latent channels (2 * 2 * 4 for AuraFlow), so this is the number of patches
+        latent_size = height * width * num_channels / 16
+        pos_embed_size = self.pos_embed.shape[1]
+        if latent_size < pos_embed_size:
+            total_padding = int(pos_embed_size - math.floor(latent_size))
+            total_padding = total_padding // 16
+            pad_height = total_padding // 2
+            pad_width = total_padding - pad_height
+            # reflect padding on the right and bottom edges
+            padding = (0, pad_width, 0, pad_height)
+            latent = torch.nn.functional.pad(latent, padding, mode='reflect')
+        elif latent_size > pos_embed_size:
+            # ceil so the slice index is an int and at least one row is cropped
+            amount_to_remove = math.ceil(latent_size - pos_embed_size)
+            latent = latent[:, :, :-amount_to_remove]
+
+        batch_size, num_channels, height, width = latent.size()
+
+        latent = latent.view(
+            batch_size,
+            num_channels,
+            height // self.patch_size,
+            self.patch_size,
+            width // self.patch_size,
+            self.patch_size,
+        )
+        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
+        latent = self.proj(latent)
+        try:
+            return latent + self.pos_embed
+        except RuntimeError:
+            raise RuntimeError(
+                f"Positional embeddings are too small for the number of patches. "
+                f"Please increase `pos_embed_max_size` to at least {self.num_patches}."
+            )
+
+    pos_embed.forward = partial(new_forward, pos_embed)
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 08880a82..6ee84a89 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -27,6 +27,7 @@ from library.model_util import convert_unet_state_dict_to_sd, convert_text_encod
 from toolkit import train_tools
 from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
+from toolkit.models.auraflow import patch_auraflow_pos_embed
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
 from toolkit.reference_adapter import ReferenceAdapter
@@ -40,13 +41,14 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
     StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
+    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
+    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel
 import diffusers
 from diffusers import \
     AutoencoderKL, \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
-from transformers import T5EncoderModel, BitsAndBytesConfig
+from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
@@ -149,6 +151,7 @@ class StableDiffusion:
         self.is_v3 = model_config.is_v3
         self.is_vega = model_config.is_vega
         self.is_pixart = model_config.is_pixart
+        self.is_auraflow = model_config.is_auraflow
         self.use_text_encoder_1 = model_config.use_text_encoder_1
         self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -371,6 +374,68 @@ class StableDiffusion:
             text_encoder.eval()
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
             tokenizer = pipe.tokenizer
+
+
+        elif self.model_config.is_auraflow:
+            te_kwargs = {}
+            # handle quantization of TE
+            te_is_quantized = False
+            if self.model_config.text_encoder_bits == 8:
+                te_kwargs['load_in_8bit'] = True
+                te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+            elif self.model_config.text_encoder_bits == 4:
+                te_kwargs['load_in_4bit'] = True
+                te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+
+            main_model_path = model_path
+
+            # load the TE at the configured precision (quantized to 8 or 4 bit if set above)
+            text_encoder = UMT5EncoderModel.from_pretrained(
+                main_model_path,
+                subfolder="text_encoder",
+                torch_dtype=self.torch_dtype,
+                **te_kwargs
+            )
+
+            # load the transformer
+            subfolder = "transformer"
+            # check if the path points directly at the transformer weights
+            if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
+                subfolder = None
+
+            if te_is_quantized:
+                # replace the to function with a no-op since it throws an error instead of a warning
+                text_encoder.to = lambda *args, **kwargs: None
+
+            # load only the transformer from the checkpoint (unet_path takes precedence if set)
+            transformer = AuraFlowTransformer2DModel.from_pretrained(
model_path if self.model_config.unet_path is None else self.model_config.unet_path, + torch_dtype=self.torch_dtype, + subfolder='transformer' + ) + pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained( + main_model_path, + transformer=transformer, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ) + + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + + # patch auraflow so it can handle other aspect ratios + patch_auraflow_pos_embed(pipe.transformer.pos_embed) + + flush() + # text_encoder = pipe.text_encoder + # text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + tokenizer = pipe.tokenizer else: if self.custom_pipeline is not None: pipln = self.custom_pipeline @@ -418,7 +483,7 @@ class StableDiffusion: # add hacks to unet to help training # pipe.unet = prepare_unet_for_training(pipe.unet) - if self.is_pixart or self.is_v3: + if self.is_pixart or self.is_v3 or self.is_auraflow: # pixart and sd3 dont use a unet self.unet = pipe.transformer else: @@ -621,6 +686,16 @@ class StableDiffusion: **extra_args ) + elif self.is_auraflow: + pipeline = AuraFlowPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ) + else: pipeline = Pipe( vae=self.vae, @@ -846,6 +921,24 @@ class StableDiffusion: ).images[0] elif self.is_pixart: # needs attention masks for some reason + img = pipeline( + prompt=None, + prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype), + negative_prompt=None, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + **extra + ).images[0] + elif self.is_auraflow: + pipeline: AuraFlowPipeline = pipeline + img = pipeline( prompt=None, prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), @@ -1309,6 +1402,18 @@ class StableDiffusion: **kwargs, ).sample noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep) + elif self.is_auraflow: + # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0]) + t = t.to(self.device_torch, self.torch_dtype) + + noise_pred = self.unet( + latent_model_input, + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + timestep=t, + return_dict=False, + )[0] else: noise_pred = self.unet( latent_model_input.to(self.device_torch, self.torch_dtype), @@ -1502,6 +1607,19 @@ class StableDiffusion: embeds, attention_mask=attention_mask, ) + elif self.is_auraflow: + embeds, attention_mask = train_tools.encode_prompts_auraflow( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=256, + dropout_prob=dropout_prob + ) + return PromptEmbeds( + 
embeds, + attention_mask=attention_mask, # not used + ) elif isinstance(self.text_encoder, T5EncoderModel): embeds, attention_mask = train_tools.encode_prompts_pixart( @@ -1835,7 +1953,7 @@ class StableDiffusion: named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True) unet_lr = unet_lr if unet_lr is not None else default_lr params = [] - if self.is_pixart: + if self.is_pixart or self.is_auraflow: for param in named_params.values(): if param.requires_grad: params.append(param) @@ -1881,7 +1999,7 @@ class StableDiffusion: def save_device_state(self): # saves the current device state for all modules # this is useful for when we want to alter the state and restore it - if self.is_pixart or self.is_v3: + if self.is_pixart or self.is_v3 or self.is_auraflow: unet_has_grad = self.unet.proj_out.weight.requires_grad else: unet_has_grad = self.unet.conv_in.weight.requires_grad @@ -1912,7 +2030,7 @@ class StableDiffusion: 'requires_grad': te_has_grad }) else: - if isinstance(self.text_encoder, T5EncoderModel): + if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel): te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad else: te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py index db15e63f..7d492441 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ -30,7 +30,7 @@ from diffusers import ( from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import torch import re -from transformers import T5Tokenizer, T5EncoderModel +from transformers import T5Tokenizer, T5EncoderModel, UMT5EncoderModel SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 @@ -725,6 +725,48 @@ def encode_prompts_pixart( return prompt_embeds.last_hidden_state, prompt_attention_mask +def encode_prompts_auraflow( + tokenizer: 'T5Tokenizer', + text_encoder: 'UMT5EncoderModel', + prompts: list[str], + truncate: bool = True, + max_length=None, + dropout_prob=0.0, +): + if max_length is None: + max_length = 256 + + if dropout_prob > 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + device = text_encoder.device + + text_inputs = tokenizer( + prompts, + truncation=True, + max_length=max_length, + padding="max_length", + return_tensors="pt", + ) + text_inputs = {k: v.to(device) for k, v in text_inputs.items()} + text_input_ids = text_inputs["input_ids"] + untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + + prompt_embeds = text_encoder(**text_inputs)[0] + prompt_attention_mask = text_inputs["attention_mask"].unsqueeze(-1).expand(prompt_embeds.shape) + prompt_embeds = prompt_embeds * prompt_attention_mask + + return prompt_embeds, prompt_attention_mask + + # for XL def get_add_time_ids( height: int,