From 93b52932c17b8e963db4e4cac73eb271036f0dce Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 13 Feb 2024 16:00:04 -0700 Subject: [PATCH] Added training for pixart-a --- extensions_built_in/sd_trainer/SDTrainer.py | 6 +- jobs/process/BaseSDTrainProcess.py | 3 +- toolkit/config_modules.py | 4 + toolkit/custom_adapter.py | 38 ++++- toolkit/lora_special.py | 8 +- toolkit/prompt_utils.py | 21 ++- toolkit/sampler.py | 34 ++++- toolkit/saving.py | 2 + toolkit/stable_diffusion_model.py | 153 ++++++++++++++++++-- toolkit/train_tools.py | 43 ++++++ 10 files changed, 288 insertions(+), 24 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 3997395e..ab257209 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -225,6 +225,9 @@ class SDTrainer(BaseSDTrainProcess): noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True) noise_pred = noise_pred * (noise_norm / noise_pred_norm) + if self.train_config.pred_scaler != 1.0: + noise_pred = noise_pred * self.train_config.pred_scaler + target = None if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask): if self.train_config.correct_pred_norm and not is_reg: @@ -343,7 +346,8 @@ class SDTrainer(BaseSDTrainProcess): print("Prior loss is nan") prior_loss = None else: - prior_loss = prior_loss.mean([1, 2, 3]) + # prior_loss = prior_loss.mean([1, 2, 3]) + loss = loss + prior_loss # loss = loss + prior_loss loss = loss.mean([1, 2, 3]) if prior_loss is not None: diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 70ae28b2..37834a4d 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1054,7 +1054,8 @@ class BaseSDTrainProcess(BaseTrainProcess): self.train_config.noise_scheduler, { "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", - } + }, + 'sd' if not self.model_config.is_pixart else 'pixart' ) if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None: diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 675ccbd8..1768697d 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -304,12 +304,16 @@ class TrainConfig: self.loss_type = kwargs.get('loss_type', 'mse') + # scale the prediction by this. 
Increase for more detail, decrease for less + self.pred_scaler = kwargs.get('pred_scaler', 1.0) + class ModelConfig: def __init__(self, **kwargs): self.name_or_path: str = kwargs.get('name_or_path', None) self.is_v2: bool = kwargs.get('is_v2', False) self.is_xl: bool = kwargs.get('is_xl', False) + self.is_pixart: bool = kwargs.get('is_pixart', False) self.is_ssd: bool = kwargs.get('is_ssd', False) self.is_vega: bool = kwargs.get('is_vega', False) self.is_v_pred: bool = kwargs.get('is_v_pred', False) diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index 7ec235bf..694869d8 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -13,7 +13,7 @@ from toolkit.models.te_adapter import TEAdapter from toolkit.models.vd_adapter import VisionDirectAdapter from toolkit.paths import REPOS_ROOT from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder -from toolkit.saving import load_ip_adapter_model +from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model from toolkit.train_tools import get_torch_dtype sys.path.append(REPOS_ROOT) @@ -99,6 +99,13 @@ class CustomAdapter(torch.nn.Module): tokenizer.add_tokens([self.flag_word], special_tokens=True) else: self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True) + elif self.config.name_or_path is not None: + loaded_state_dict = load_custom_adapter_model( + self.config.name_or_path, + self.sd_ref().device, + dtype=self.sd_ref().dtype, + ) + self.load_state_dict(loaded_state_dict, strict=False) def setup_adapter(self): if self.adapter_type == 'photo_maker': @@ -287,6 +294,9 @@ class CustomAdapter(torch.nn.Module): def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): strict = False + if self.config.train_only_image_encoder and 'vd_adapter' not in state_dict and 'dvadapter' not in state_dict: + # we are loading pure clip weights. 
+ self.vision_encoder.load_state_dict(state_dict, strict=strict) if 'lora_weights' in state_dict: # todo add LoRA @@ -332,6 +342,8 @@ class CustomAdapter(torch.nn.Module): if 'vd_adapter' in state_dict: self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict) + if 'dvadapter' in state_dict: + self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict) if 'vision_encoder' in state_dict and self.config.train_image_encoder: self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict) @@ -346,6 +358,9 @@ class CustomAdapter(torch.nn.Module): def state_dict(self) -> OrderedDict: state_dict = OrderedDict() + if self.config.train_only_image_encoder: + return self.vision_encoder.state_dict() + if self.adapter_type == 'photo_maker': if self.config.train_image_encoder: state_dict["id_encoder"] = self.vision_encoder.state_dict() @@ -364,7 +379,9 @@ class CustomAdapter(torch.nn.Module): state_dict["te_adapter"] = self.te_adapter.state_dict() return state_dict elif self.adapter_type == 'vision_direct': - state_dict["vd_adapter"] = self.vd_adapter.state_dict() + state_dict["dvadapter"] = self.vd_adapter.state_dict() + if self.config.train_image_encoder: + state_dict["vision_encoder"] = self.vision_encoder.state_dict() return state_dict elif self.adapter_type == 'ilora': if self.config.train_image_encoder: @@ -617,6 +634,12 @@ class CustomAdapter(torch.nn.Module): clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1]) return clip_image.detach() + def train(self, mode: bool = True): + if self.config.train_image_encoder: + self.vision_encoder.train(mode) + else: + super().train(mode) + def trigger_pre_te( self, tensors_0_1: torch.Tensor, @@ -735,6 +758,9 @@ class CustomAdapter(torch.nn.Module): self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0) def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + if self.config.train_only_image_encoder: + yield from self.vision_encoder.parameters(recurse) + return if self.config.type == 'photo_maker': yield from self.fuse_module.parameters(recurse) if self.config.train_image_encoder: @@ -753,5 +779,13 @@ class CustomAdapter(torch.nn.Module): elif self.config.type == 'vision_direct': for attn_processor in self.vd_adapter.adapter_modules: yield from attn_processor.parameters(recurse) + if self.config.train_image_encoder: + yield from self.vision_encoder.parameters(recurse) else: raise NotImplementedError + + def enable_gradient_checkpointing(self): + if hasattr(self.vision_encoder, "enable_gradient_checkpointing"): + self.vision_encoder.enable_gradient_checkpointing() + elif hasattr(self.vision_encoder, 'gradient_checkpointing'): + self.vision_encoder.gradient_checkpointing = True diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 2fd18ab5..744111bf 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -151,6 +151,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): train_unet: Optional[bool] = True, is_sdxl=False, is_v2=False, + is_pixart: bool = False, use_bias: bool = False, is_lorm: bool = False, ignore_if_contains = None, @@ -197,6 +198,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.multiplier = multiplier self.is_sdxl = is_sdxl self.is_v2 = is_v2 + self.is_pixart = is_pixart if modules_dim is not None: print(f"create LoRA network from weights") @@ -224,8 +226,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): root_module: torch.nn.Module, target_replace_modules: 
List[torch.nn.Module], ) -> List[LoRAModule]: + unet_prefix = self.LORA_PREFIX_UNET + if is_pixart: + unet_prefix = f"lora_transformer" + prefix = ( - self.LORA_PREFIX_UNET + unet_prefix if is_unet else ( self.LORA_PREFIX_TEXT_ENCODER diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 4a132967..a145841c 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -19,10 +19,11 @@ class ACTION_TYPES_SLIDER: class PromptEmbeds: - text_embeds: torch.Tensor - pooled_embeds: Union[torch.Tensor, None] + # text_embeds: torch.Tensor + # pooled_embeds: Union[torch.Tensor, None] + # attention_mask: Union[torch.Tensor, None] - def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None: + def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None: if isinstance(args, list) or isinstance(args, tuple): # xl self.text_embeds = args[0] @@ -32,10 +33,14 @@ class PromptEmbeds: self.text_embeds = args self.pooled_embeds = None + self.attention_mask = attention_mask + def to(self, *args, **kwargs): self.text_embeds = self.text_embeds.to(*args, **kwargs) if self.pooled_embeds is not None: self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs) + if self.attention_mask is not None: + self.attention_mask = self.attention_mask.to(*args, **kwargs) return self def detach(self): @@ -43,13 +48,19 @@ class PromptEmbeds: new_embeds.text_embeds = new_embeds.text_embeds.detach() if new_embeds.pooled_embeds is not None: new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach() + if new_embeds.attention_mask is not None: + new_embeds.attention_mask = new_embeds.attention_mask.detach() return new_embeds def clone(self): if self.pooled_embeds is not None: - return PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()]) + prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()]) else: - return PromptEmbeds(self.text_embeds.clone()) + prompt_embeds = PromptEmbeds(self.text_embeds.clone()) + + if self.attention_mask is not None: + prompt_embeds.attention_mask = self.attention_mask.clone() + return prompt_embeds class EncodedPromptPair: diff --git a/toolkit/sampler.py b/toolkit/sampler.py index 2e9f1654..dae8b99b 100644 --- a/toolkit/sampler.py +++ b/toolkit/sampler.py @@ -1,4 +1,5 @@ import copy +import math from diffusers import ( DDPMScheduler, @@ -25,7 +26,7 @@ SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 SCHEDLER_SCHEDULE = "scaled_linear" -sdxl_sampler_config = { +sd_config = { "_class_name": "EulerAncestralDiscreteScheduler", "_diffusers_version": "0.24.0.dev0", "beta_end": 0.012, @@ -43,15 +44,44 @@ sdxl_sampler_config = { "trained_betas": None } +pixart_config = { + "_class_name": "DPMSolverMultistepScheduler", + "_diffusers_version": "0.22.0.dev0", + "algorithm_type": "dpmsolver++", + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "dynamic_thresholding_ratio": 0.995, + "euler_at_final": False, + # "lambda_min_clipped": -Infinity, + "lambda_min_clipped": -math.inf, + "lower_order_final": True, + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "solver_order": 2, + "solver_type": "midpoint", + "steps_offset": 0, + "thresholding": False, + "timestep_spacing": "linspace", + "trained_betas": None, + "use_karras_sigmas": False, + "use_lu_lambdas": False, + "variance_type": None +} + def get_sampler( sampler: str, kwargs: dict = None, + arch: str = "sd" ): sched_init_args = {} if 
kwargs is not None: sched_init_args.update(kwargs) + config_to_use = copy.deepcopy(sd_config) if arch == "sd" else copy.deepcopy(pixart_config) + if sampler.startswith("k_"): sched_init_args["use_karras_sigmas"] = True @@ -83,7 +113,7 @@ def get_sampler( elif sampler == "custom_lcm": scheduler_cls = CustomLCMScheduler - config = copy.deepcopy(sdxl_sampler_config) + config = copy.deepcopy(config_to_use) config.update(sched_init_args) scheduler = scheduler_cls.from_config(config) diff --git a/toolkit/saving.py b/toolkit/saving.py index e2a90abb..7abc7d50 100644 --- a/toolkit/saving.py +++ b/toolkit/saving.py @@ -263,6 +263,8 @@ def load_custom_adapter_model( if path_to_file.endswith('.safetensors'): raw_state_dict = load_file(path_to_file, device) combined_state_dict = OrderedDict() + device = device if isinstance(device, torch.device) else torch.device(device) + dtype = dtype if isinstance(dtype, torch.dtype) else get_torch_dtype(dtype) for combo_key, value in raw_state_dict.items(): key_split = combo_key.split('.') module_name = key_split.pop(0) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index bd0056a2..2fc4cc5e 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -11,6 +11,7 @@ from collections import OrderedDict import yaml from PIL import Image +from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from safetensors.torch import save_file, load_file from torch.nn import Parameter @@ -43,6 +44,8 @@ import diffusers from diffusers import \ AutoencoderKL, \ UNet2DConditionModel +from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler +from transformers import T5EncoderModel from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT @@ -121,7 +124,7 @@ class StableDiffusion: self.device_state = None - self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline'] + self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] self.vae: Union[None, 'AutoencoderKL'] self.unet: Union[None, 'UNet2DConditionModel'] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] @@ -142,6 +145,7 @@ class StableDiffusion: self.is_v2 = model_config.is_v2 self.is_ssd = model_config.is_ssd self.is_vega = model_config.is_vega + self.is_pixart = model_config.is_pixart self.use_text_encoder_1 = model_config.use_text_encoder_1 self.use_text_encoder_2 = model_config.use_text_encoder_2 @@ -157,7 +161,9 @@ class StableDiffusion: scheduler = get_sampler( 'ddpm', { "prediction_type": self.prediction_type, - }) + }, + 'sd' if not self.is_pixart else 'pixart' + ) self.noise_scheduler = scheduler # move the betas alphas and alphas_cumprod to device. 
Sometimed they get stuck on cpu, not sure why @@ -227,7 +233,33 @@ class StableDiffusion: te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype) flush() print("Injecting alt weights") + elif self.model_config.is_pixart: + # load the TE in 8bit mode + text_encoder = T5EncoderModel.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + subfolder="text_encoder", + load_in_8bit=True, + device_map="auto", + torch_dtype=self.torch_dtype, + ) + # replace the to function with a no-op since it throws an error instead of a warning + text_encoder.to = lambda *args, **kwargs: None + pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained( + model_path, + text_encoder=text_encoder, + dtype=dtype, + device=self.device_torch, + **load_args + ).to(self.device_torch) + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + flush() + # text_encoder = pipe.text_encoder + # text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype) + tokenizer = pipe.tokenizer else: if self.custom_pipeline is not None: pipln = self.custom_pipeline @@ -273,7 +305,11 @@ class StableDiffusion: # add hacks to unet to help training # pipe.unet = prepare_unet_for_training(pipe.unet) - self.unet: 'UNet2DConditionModel' = pipe.unet + if self.is_pixart: + # pixart doesnt use a unet + self.unet = pipe.transformer + else: + self.unet: 'UNet2DConditionModel' = pipe.unet self.vae: 'AutoencoderKL' = pipe.vae.to(self.device_torch, dtype=dtype) self.vae.eval() self.vae.requires_grad_(False) @@ -381,7 +417,8 @@ class StableDiffusion: sampler, { "prediction_type": self.prediction_type, - } + }, + 'sd' if not self.is_pixart else 'pixart' ) try: @@ -425,6 +462,16 @@ class StableDiffusion: **extra_args ).to(self.device_torch) pipeline.watermark = None + elif self.is_pixart: + pipeline = PixArtAlphaPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ).to(self.device_torch) + else: pipeline = Pipe( vae=self.vae, @@ -615,6 +662,23 @@ class StableDiffusion: latents=gen_config.latents, **extra ).images[0] + elif self.is_pixart: + # needs attention masks for some reason + img = pipeline( + prompt=None, + prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype), + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype), + negative_prompt=None, + # negative_prompt=gen_config.negative_prompt, + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + **extra + ).images[0] else: img = pipeline( # prompt=gen_config.prompt, @@ -1005,12 +1069,53 @@ class StableDiffusion: f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") # predict the noise residual - noise_pred = self.unet( - latent_model_input.to(self.device_torch, self.torch_dtype), - timestep, - encoder_hidden_states=text_embeddings.text_embeds, - **kwargs, - ).sample + if self.is_pixart: + VAE_SCALE_FACTOR = 2 ** 
(len(self.vae.config['block_out_channels']) - 1) + batch_size, ch, h, w = list(latents.shape) + + height = h * VAE_SCALE_FACTOR + width = w * VAE_SCALE_FACTOR + aspect_ratio_bin = ( + ASPECT_RATIO_1024_BIN if self.unet.config.sample_size == 128 else ASPECT_RATIO_512_BIN + ) + orig_height, orig_width = height, width + height, width = self.pipeline.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.unet.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) + resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) + aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + encoder_hidden_states=text_embeddings.text_embeds, + encoder_attention_mask=text_embeddings.attention_mask, + timestep=timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + **kwargs + )[0] + + # learned sigma + if self.unet.config.out_channels // 2 == self.unet.config.in_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + else: + noise_pred = self.unet( + latent_model_input.to(self.device_torch, self.torch_dtype), + timestep, + encoder_hidden_states=text_embeddings.text_embeds, + **kwargs, + ).sample if do_classifier_free_guidance: # perform guidance @@ -1142,6 +1247,20 @@ class StableDiffusion: dropout_prob=dropout_prob, ) ) + elif self.is_pixart: + embeds, attention_mask = train_tools.encode_prompts_pixart( + self.tokenizer, + self.text_encoder, + prompt, + truncate=not long_prompts, + max_length=max_length, + dropout_prob=dropout_prob + ) + return PromptEmbeds( + embeds, + attention_mask=attention_mask, + ) + else: return PromptEmbeds( train_tools.encode_prompts( @@ -1489,6 +1608,11 @@ class StableDiffusion: def save_device_state(self): # saves the current device state for all modules # this is useful for when we want to alter the state and restore it + if self.is_pixart: + unet_has_grad = self.unet.proj_out.weight.requires_grad + else: + unet_has_grad = self.unet.conv_in.weight.requires_grad + self.device_state = { **empty_preset, 'vae': { @@ -1498,7 +1622,7 @@ class StableDiffusion: 'unet': { 'training': self.unet.training, 'device': self.unet.device, - 'requires_grad': self.unet.conv_in.weight.requires_grad, + 'requires_grad': unet_has_grad, }, } if isinstance(self.text_encoder, list): @@ -1511,10 +1635,15 @@ class StableDiffusion: 'requires_grad': encoder.text_model.final_layer_norm.weight.requires_grad }) else: + if isinstance(self.text_encoder, T5EncoderModel): + te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + else: + te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad + self.device_state['text_encoder'] = { 'training': self.text_encoder.training, 'device': self.text_encoder.device, - 'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad + 'requires_grad': te_has_grad } if self.adapter is not None: if isinstance(self.adapter, IPAdapter): diff 
--git a/toolkit/train_tools.py b/toolkit/train_tools.py index b466296b..629d9ec5 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ -29,6 +29,7 @@ from diffusers import ( from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline import torch import re +from transformers import T5Tokenizer, T5EncoderModel SCHEDULER_LINEAR_START = 0.00085 SCHEDULER_LINEAR_END = 0.0120 @@ -627,6 +628,48 @@ def encode_prompts( return text_embeddings +def encode_prompts_pixart( + tokenizer: 'T5Tokenizer', + text_encoder: 'T5EncoderModel', + prompts: list[str], + truncate: bool = True, + max_length=None, + dropout_prob=0.0, +): + if max_length is None: + # See Section 3.1. of the paper. + max_length = 120 + + if dropout_prob > 0.0: + # randomly drop out prompts + prompts = [ + prompt if torch.rand(1).item() > dropout_prob else "" for prompt in prompts + ] + + text_inputs = tokenizer( + prompts, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_length - 1: -1]) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(text_encoder.device) + + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device), attention_mask=prompt_attention_mask) + + return prompt_embeds.last_hidden_state, prompt_attention_mask + + # for XL def get_add_time_ids( height: int,
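
Usage note (not part of the patch): a minimal sketch of the prompt-encoding path this commit adds in toolkit/train_tools.py (encode_prompts_pixart). PixArt-α conditions on T5 last-hidden-state embeddings plus an attention mask; the patch threads that mask through PromptEmbeds and into PixArtAlphaPipeline as prompt_embeds / prompt_attention_mask. The checkpoint id "PixArt-alpha/PixArt-XL-2-1024-MS" and the max_length of 120 come from the diff above; the prompt text, dtype handling, and the use of subfolders to load the tokenizer and encoder directly are illustrative assumptions, not toolkit API.

    import torch
    from transformers import T5EncoderModel, T5Tokenizer

    repo = "PixArt-alpha/PixArt-XL-2-1024-MS"  # same checkpoint the patch loads the 8-bit TE from
    tokenizer = T5Tokenizer.from_pretrained(repo, subfolder="tokenizer")
    text_encoder = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder")

    prompts = ["a watercolor painting of a lighthouse at dusk"]  # illustrative prompt
    inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=120,  # Section 3.1 of the PixArt-alpha paper, as in encode_prompts_pixart
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeds = text_encoder(
            inputs.input_ids, attention_mask=inputs.attention_mask
        ).last_hidden_state

    # embeds corresponds to PromptEmbeds.text_embeds and inputs.attention_mask to
    # PromptEmbeds.attention_mask; generate_images then forwards them to
    # PixArtAlphaPipeline as prompt_embeds and prompt_attention_mask.

The same attention mask is also what predict_noise passes to the transformer as encoder_attention_mask, which is why PromptEmbeds gains the attention_mask field in toolkit/prompt_utils.py.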