diff --git a/jobs/process/TrainESRGANProcess.py b/jobs/process/TrainESRGANProcess.py
index 477d192..599eed1 100644
--- a/jobs/process/TrainESRGANProcess.py
+++ b/jobs/process/TrainESRGANProcess.py
@@ -3,10 +3,12 @@ import glob
 import os
 import time
 from collections import OrderedDict
+from typing import List, Optional
 
 from PIL import Image
 from PIL.ImageOps import exif_transpose
-# from basicsr.archs.rrdbnet_arch import RRDBNet
+
+from toolkit.basic import flush
 from toolkit.models.RRDB import RRDBNet as ESRGAN, esrgan_safetensors_keys
 from safetensors.torch import save_file, load_file
 from torch.utils.data import DataLoader, ConcatDataset
@@ -67,9 +69,10 @@ class TrainESRGANProcess(BaseTrainProcess):
         self.augmentations = self.get_conf('augmentations', {})
         self.torch_dtype = get_torch_dtype(self.dtype)
         if self.torch_dtype == torch.bfloat16:
-            self.esrgan_dtype = torch.float16
+            self.esrgan_dtype = torch.float32
         else:
             self.esrgan_dtype = torch.float32
+        self.vgg_19 = None
 
         self.style_weight_scalers = []
         self.content_weight_scalers = []
@@ -232,6 +235,7 @@ class TrainESRGANProcess(BaseTrainProcess):
                 pattern_size=self.zoom,
                 dtype=self.torch_dtype
             ).to(self.device, dtype=self.torch_dtype)
+        self._pattern_loss = self._pattern_loss.to(self.device, dtype=self.torch_dtype)
         loss = torch.mean(self._pattern_loss(pred, target))
         return loss
 
@@ -269,13 +273,52 @@ class TrainESRGANProcess(BaseTrainProcess):
         if self.use_critic:
             self.critic.save(step)
 
-    def sample(self, step=None):
+    def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None):
         sample_folder = os.path.join(self.save_root, 'samples')
         if not os.path.exists(sample_folder):
             os.makedirs(sample_folder, exist_ok=True)
+        batch_sample_folder = os.path.join(self.save_root, 'samples_batch')
+
+        batch_targets = None
+        batch_inputs = None
+        if batch is not None and not os.path.exists(batch_sample_folder):
+            os.makedirs(batch_sample_folder, exist_ok=True)
 
         self.model.eval()
 
+        def process_and_save(img, target_img, save_path):
+            output = self.model(img.to(self.device, dtype=self.esrgan_dtype))
+            # output = (output / 2 + 0.5).clamp(0, 1)
+            output = output.clamp(0, 1)
+            # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+            output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
+
+            # convert to pillow image
+            output = Image.fromarray((output * 255).astype(np.uint8))
+
+            if isinstance(target_img, torch.Tensor):
+                # convert to pil
+                target_img = target_img.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
+                target_img = Image.fromarray((target_img * 255).astype(np.uint8))
+
+            # upscale to size * self.upscale_sample while maintaining pixels
+            output = output.resize(
+                (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
+                resample=Image.NEAREST
+            )
+
+            width, height = output.size
+
+            # stack input image and decoded image
+            target_image = target_img.resize((width, height))
+            output = output.resize((width, height))
+
+            output_img = Image.new('RGB', (width * 2, height))
+            output_img.paste(target_image, (0, 0))
+            output_img.paste(output, (width, 0))
+
+            output_img.save(save_path)
+
         with torch.no_grad():
             for i, img_url in enumerate(self.sample_sources):
                 img = exif_transpose(Image.open(img_url))
@@ -295,30 +338,6 @@ class TrainESRGANProcess(BaseTrainProcess):
 
                 img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)
                 img = img
-                output = self.model(img)
-                # output = (output / 2 + 0.5).clamp(0, 1)
-                output = output.clamp(0, 1)
-                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-                output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
-
-                # convert to pillow image
-                output = Image.fromarray((output * 255).astype(np.uint8))
-
-                # upscale to size * self.upscale_sample while maintaining pixels
-                output = output.resize(
-                    (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
-                    resample=Image.NEAREST
-                )
-
-                width, height = output.size
-
-                # stack input image and decoded image
-                target_image = target_image.resize((width, height))
-                output = output.resize((width, height))
-
-                output_img = Image.new('RGB', (width * 2, height))
-                output_img.paste(target_image, (0, 0))
-                output_img.paste(output, (width, 0))
 
                 step_num = ''
                 if step is not None:
@@ -328,7 +347,23 @@ class TrainESRGANProcess(BaseTrainProcess):
                 # zero-pad 2 digits
                 i_str = str(i).zfill(2)
                 filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
-                output_img.save(os.path.join(sample_folder, filename))
+                process_and_save(img, target_image, os.path.join(sample_folder, filename))
+
+            if batch is not None:
+                batch_targets = batch[0].detach()
+                batch_inputs = batch[1].detach()
+                batch_targets = torch.chunk(batch_targets, batch_targets.shape[0], dim=0)
+                batch_inputs = torch.chunk(batch_inputs, batch_inputs.shape[0], dim=0)
+
+                for i in range(len(batch_inputs)):
+                    if step is not None:
+                        # zero-pad 9 digits
+                        step_num = f"_{str(step).zfill(9)}"
+                    seconds_since_epoch = int(time.time())
+                    # zero-pad 2 digits
+                    i_str = str(i).zfill(2)
+                    filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
+                    process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
 
         self.model.train()
 
@@ -445,35 +480,60 @@ class TrainESRGANProcess(BaseTrainProcess):
         print("Generating baseline samples")
         self.sample(step=0)
         # range start at self.epoch_num go to self.epochs
+        critic_losses = []
         for epoch in range(self.epoch_num, self.epochs, 1):
             if self.step_num >= self.max_steps:
                 break
+            flush()
             for targets, inputs in self.data_loader:
                 if self.step_num >= self.max_steps:
                     break
                 with torch.no_grad():
-                    targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
-                    inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
+                    is_critic_only_step = False
+                    if self.use_critic and 1 / (self.critic.num_critic_per_gen + 1) < np.random.uniform():
+                        is_critic_only_step = True
 
-                pred = self.model(inputs)
+                    targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
+                    inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1).detach()
 
-                pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
-                targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
+                optimizer.zero_grad()
+                # dont do grads here for critic step
+                do_grad = not is_critic_only_step
+                with torch.set_grad_enabled(do_grad):
+                    pred = self.model(inputs)
 
-                # Run through VGG19
-                if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
-                    stacked = torch.cat([pred, targets], dim=0)
-                    # stacked = (stacked / 2 + 0.5).clamp(0, 1)
-                    stacked = stacked.clamp(0, 1)
-                    self.vgg_19(stacked)
+                    pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
+                    targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
+                    if torch.isnan(pred).any():
+                        raise ValueError('pred has nan values')
+                    if torch.isnan(targets).any():
+                        raise ValueError('targets has nan values')
 
-                if self.use_critic:
+                    # Run through VGG19
+                    if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+                        stacked = torch.cat([pred, targets], dim=0)
+                        # stacked = (stacked / 2 + 0.5).clamp(0, 1)
+                        stacked = stacked.clamp(0, 1)
+                        self.vgg_19(stacked)
+                        # make sure we dont have nans
+                        if torch.isnan(self.vgg19_pool_4.tensor).any():
+                            raise ValueError('vgg19_pool_4 has nan values')
+
+                if is_critic_only_step:
                     critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
+                    critic_losses.append(critic_d_loss)
+                    # don't do generator step
+                    continue
                 else:
-                    critic_d_loss = 0.0
+                    # doing a regular step
+                    if len(critic_losses) == 0:
+                        critic_d_loss = 0
+                    else:
+                        critic_d_loss = sum(critic_losses) / len(critic_losses)
 
                 style_loss = self.get_style_loss() * self.style_weight
                 content_loss = self.get_content_loss() * self.content_weight
+
                 mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight
                 tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight
                 pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight
@@ -483,10 +543,13 @@ class TrainESRGANProcess(BaseTrainProcess):
                     critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
 
                 loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
+                # make sure non nan
+                if torch.isnan(loss):
+                    raise ValueError('loss is nan')
 
                 # Backward pass and optimization
-                optimizer.zero_grad()
                 loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                 optimizer.step()
                 scheduler.step()
@@ -549,7 +612,7 @@ class TrainESRGANProcess(BaseTrainProcess):
                 if self.sample_every and self.step_num % self.sample_every == 0:
                     # print above the progress bar
                     self.print(f"Sampling at step {self.step_num}")
-                    self.sample(self.step_num)
+                    self.sample(self.step_num, batch=[targets, inputs])
 
                 if self.save_every and self.step_num % self.save_every == 0:
                     # print above the progress bar
diff --git a/jobs/process/models/vgg19_critic.py b/jobs/process/models/vgg19_critic.py
index a5ef92b..8cf438b 100644
--- a/jobs/process/models/vgg19_critic.py
+++ b/jobs/process/models/vgg19_critic.py
@@ -154,28 +154,28 @@ class Critic:
         # train critic here
         self.model.train()
         self.model.requires_grad_(True)
+        self.optimizer.zero_grad()
         critic_losses = []
-        for i in range(self.num_critic_per_gen):
-            inputs = vgg_output.detach()
-            inputs = inputs.to(self.device, dtype=self.torch_dtype)
-            self.optimizer.zero_grad()
+        inputs = vgg_output.detach()
+        inputs = inputs.to(self.device, dtype=self.torch_dtype)
+        self.optimizer.zero_grad()
 
-            vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
+        vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
 
-            stacked_output = self.model(inputs)
-            out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
+        stacked_output = self.model(inputs).float()
+        out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
 
-            # Compute gradient penalty
-            gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
+        # Compute gradient penalty
+        gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
 
-            # Compute WGAN-GP critic loss
-            critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
-            critic_loss.backward()
-            self.optimizer.zero_grad()
-            self.optimizer.step()
-            self.scheduler.step()
-            critic_losses.append(critic_loss.item())
+        # Compute WGAN-GP critic loss
+        critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
+        critic_loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+        self.optimizer.step()
+        self.scheduler.step()
+        critic_losses.append(critic_loss.item())
 
         # avg loss
         loss = np.mean(critic_losses)
 
diff --git a/run.py b/run.py
index 54728c3..bb45cd5 100644
--- a/run.py
+++ b/run.py
@@ -5,6 +5,12 @@ from typing import Union, OrderedDict
 sys.path.insert(0, os.getcwd())
 # must come before ANY torch or fastai imports
 # import toolkit.cuda_malloc
+
+# check if we have DEBUG_TOOLKIT in env
+if os.environ.get("DEBUG_TOOLKIT", "0") == "1":
+    # set torch to trace mode
+    import torch
+    torch.autograd.set_detect_anomaly(True)
 
 import argparse
 from toolkit.job import get_job
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 4979ae5..bf7d286 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -3,6 +3,8 @@ import time
 from typing import List, Optional, Literal, Union
 import random
 
+from toolkit.prompt_utils import PromptEmbeds
+
 ImgExt = Literal['jpg', 'png', 'webp']
 
@@ -447,3 +449,11 @@ class GenerateImageConfig:
             self.network_multiplier = float(content)
         elif flag == 'gr':
             self.guidance_rescale = float(content)
+
+    def post_process_embeddings(
+            self,
+            conditional_prompt_embeds: PromptEmbeds,
+            unconditional_prompt_embeds: Optional[PromptEmbeds] = None,
+    ):
+        # this is called after prompt embeds are encoded. We can override them in the future here
+        pass
\ No newline at end of file
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 8fe7897..bf3ffb9 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -53,6 +53,8 @@ class ImageDataset(Dataset, CaptionMixin):
             else:
                 bad_count += 1
 
+        self.file_list = new_file_list
+
         print(f" - Found {len(self.file_list)} images")
         print(f" - Found {bad_count} images that are too small")
         assert len(self.file_list) > 0, f"no images found in {self.path}"
@@ -90,7 +92,10 @@ class ImageDataset(Dataset, CaptionMixin):
                 scale_size = self.resolution
             else:
                 scale_size = random.randint(self.resolution, int(min_img_size))
-            img = img.resize((scale_size, scale_size), Image.BICUBIC)
+            scaler = scale_size / min_img_size
+            scale_width = int((img.width + 5) * scaler)
+            scale_height = int((img.height + 5) * scaler)
+            img = img.resize((scale_width, scale_height), Image.BICUBIC)
             img = transforms.RandomCrop(self.resolution)(img)
         else:
             img = transforms.CenterCrop(min_img_size)(img)
diff --git a/toolkit/llvae.py b/toolkit/llvae.py
index f8ed8f5..e1698ed 100644
--- a/toolkit/llvae.py
+++ b/toolkit/llvae.py
@@ -7,7 +7,7 @@ import itertools
 class LosslessLatentDecoder(nn.Module):
     def __init__(self, in_channels, latent_depth, dtype=torch.float32):
         super(LosslessLatentDecoder, self).__init__()
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.latent_depth = latent_depth
         self.in_channels = in_channels
         self.out_channels = int(in_channels // (latent_depth * latent_depth))
@@ -46,7 +46,7 @@ class LosslessLatentEncoder(nn.Module):
     def __init__(self, in_channels, latent_depth, dtype=torch.float32):
         super(LosslessLatentEncoder, self).__init__()
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.latent_depth = latent_depth
         self.in_channels = in_channels
         self.out_channels = int(in_channels * (latent_depth * latent_depth))
@@ -108,7 +108,7 @@ if __name__ == '__main__':
     from PIL import Image
     import torchvision.transforms as transforms
     user_path = os.path.expanduser('~')
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 input_path = os.path.join(user_path, "Pictures/sample_2_512.png") diff --git a/toolkit/losses.py b/toolkit/losses.py index aded076..eeea357 100644 --- a/toolkit/losses.py +++ b/toolkit/losses.py @@ -27,11 +27,17 @@ class ComparativeTotalVariation(torch.nn.Module): # Gradient penalty def get_gradient_penalty(critic, real, fake, device): with torch.autocast(device_type='cuda'): - alpha = torch.rand(real.size(0), 1, 1, 1).to(device) + real = real.float() + fake = fake.float() + alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float() interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True) + if torch.isnan(interpolates).any(): + print('d_interpolates is nan') d_interpolates = critic(interpolates) fake = torch.ones(real.size(0), 1, device=device) - + + if torch.isnan(d_interpolates).any(): + print('fake is nan') gradients = torch.autograd.grad( outputs=d_interpolates, inputs=interpolates, @@ -41,10 +47,14 @@ def get_gradient_penalty(critic, real, fake, device): only_inputs=True, )[0] + # see if any are nan + if torch.isnan(gradients).any(): + print('gradients is nan') + gradients = gradients.view(gradients.size(0), -1) gradient_norm = gradients.norm(2, dim=1) gradient_penalty = ((gradient_norm - 1) ** 2).mean() - return gradient_penalty + return gradient_penalty.float() class PatternLoss(torch.nn.Module): diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 9741600..c300e6e 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -44,6 +44,12 @@ class PromptEmbeds: self.pooled_embeds = self.pooled_embeds.detach() return self + def clone(self): + if self.pooled_embeds is not None: + return PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()]) + else: + return PromptEmbeds(self.text_embeds.clone()) + class EncodedPromptPair: def __init__( diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 871f76d..3fd2310 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -368,6 +368,19 @@ class StableDiffusion: torch.manual_seed(gen_config.seed) torch.cuda.manual_seed(gen_config.seed) + # encode the prompt ourselves so we can do fun stuff with embeddings + conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True) + + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True + ) + + # allow any manipulations to take place to embeddings + gen_config.post_process_embeddings( + conditional_embeds, + unconditional_embeds, + ) + # todo do we disable text encoder here as well if disabled for model, or only do that for training? 
             if self.is_xl:
                 # fix guidance rescale for sdxl
@@ -382,10 +395,14 @@ class StableDiffusion:
                     extra['use_karras_sigmas'] = True
 
                 img = pipeline(
-                    prompt=gen_config.prompt,
-                    prompt_2=gen_config.prompt_2,
-                    negative_prompt=gen_config.negative_prompt,
-                    negative_prompt_2=gen_config.negative_prompt_2,
+                    # prompt=gen_config.prompt,
+                    # prompt_2=gen_config.prompt_2,
+                    prompt_embeds=conditional_embeds.text_embeds,
+                    pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+                    negative_prompt_embeds=unconditional_embeds.text_embeds,
+                    negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+                    # negative_prompt=gen_config.negative_prompt,
+                    # negative_prompt_2=gen_config.negative_prompt_2,
                     height=gen_config.height,
                     width=gen_config.width,
                     num_inference_steps=gen_config.num_inference_steps,
@@ -395,8 +412,10 @@ class StableDiffusion:
                 ).images[0]
             else:
                 img = pipeline(
-                    prompt=gen_config.prompt,
-                    negative_prompt=gen_config.negative_prompt,
+                    # prompt=gen_config.prompt,
+                    prompt_embeds=conditional_embeds.text_embeds,
+                    negative_prompt_embeds=unconditional_embeds.text_embeds,
+                    # negative_prompt=gen_config.negative_prompt,
                     height=gen_config.height,
                     width=gen_config.width,
                     num_inference_steps=gen_config.num_inference_steps,
@@ -625,21 +644,25 @@ class StableDiffusion:
         # return latents_steps
         return latents
 
-    def encode_prompt(self, prompt, num_images_per_prompt=1) -> PromptEmbeds:
+    def encode_prompt(self, prompt, prompt2=None, num_images_per_prompt=1, force_all=False) -> PromptEmbeds:
         # sd1.5 embeddings are (bs, 77, 768)
         prompt = prompt
         # if it is not a list, make it one
         if not isinstance(prompt, list):
             prompt = [prompt]
+
+        if prompt2 is not None and not isinstance(prompt2, list):
+            prompt2 = [prompt2]
         if self.is_xl:
             return PromptEmbeds(
                 train_tools.encode_prompts_xl(
                     self.tokenizer,
                     self.text_encoder,
                     prompt,
+                    prompt2,
                     num_images_per_prompt=num_images_per_prompt,
-                    use_text_encoder_1=self.use_text_encoder_1,
-                    use_text_encoder_2=self.use_text_encoder_2,
+                    use_text_encoder_1=self.use_text_encoder_1 or force_all,
+                    use_text_encoder_2=self.use_text_encoder_2 or force_all,
                 )
             )
         else:
diff --git a/toolkit/style.py b/toolkit/style.py
index 4282a23..b08214a 100644
--- a/toolkit/style.py
+++ b/toolkit/style.py
@@ -33,12 +33,17 @@ class ContentLoss(nn.Module):
 
         # Define the separate loss function
         def separated_loss(y_pred, y_true):
+            y_pred = y_pred.float()
+            y_true = y_true.float()
             diff = torch.abs(y_pred - y_true)
             l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
             return 2. * l2 / content_size
 
         # Calculate itemized loss
         pred_itemized_loss = separated_loss(pred_layer, target_layer)
+        # check if is nan
+        if torch.isnan(pred_itemized_loss).any():
+            print('pred_itemized_loss is nan')
 
         # Calculate the mean of itemized loss
         loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
@@ -48,6 +53,7 @@ class ContentLoss(nn.Module):
 
 
 def convert_to_gram_matrix(inputs):
+    inputs = inputs.float()
     shape = inputs.size()
     batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
     size = height * width * filters
@@ -93,11 +99,14 @@ class StyleLoss(nn.Module):
         target_grams = convert_to_gram_matrix(style_target)
         pred_grams = convert_to_gram_matrix(preds)
         itemized_loss = separated_loss(pred_grams, target_grams)
+        # check if is nan
+        if torch.isnan(itemized_loss).any():
+            print('itemized_loss is nan')
         # reshape itemized loss to be (batch, 1, 1, 1)
         itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
         # gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
         loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
-        self.loss = loss.to(input_dtype)
+        self.loss = loss.to(input_dtype).float()
 
         return stacked_input.to(input_dtype)
 
@@ -149,7 +158,7 @@ def get_style_model_and_losses(
 ):
     # content_layers = ['conv_4']
     # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
-    content_layers = ['conv4_2']
+    content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']
     style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
     cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
     # set all weights in the model to our dtype
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 8bcea8a..5bee2f6 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -479,6 +479,7 @@ def encode_prompts_xl(
         tokenizers: list['CLIPTokenizer'],
         text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
         prompts: list[str],
+        prompts2: Union[list[str], None],
         num_images_per_prompt: int = 1,
         use_text_encoder_1: bool = True,  # sdxl
         use_text_encoder_2: bool = True  # sdxl
@@ -486,11 +487,13 @@ def encode_prompts_xl(
     # text_encoder and text_encoder_2's penuultimate layer's output
     text_embeds_list = []
    pooled_text_embeds = None  # always text_encoder_2's pool
+    if prompts2 is None:
+        prompts2 = prompts
 
     for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
         # todo, we are using a blank string to ignore that encoder for now.
         # find a better way to do this (zeroing?, removing it from the unet?)
-        prompt_list_to_use = prompts
+        prompt_list_to_use = prompts if idx == 0 else prompts2
         if idx == 0 and not use_text_encoder_1:
             prompt_list_to_use = ["" for _ in prompts]
         if idx == 1 and not use_text_encoder_2:
@@ -515,6 +518,7 @@ def text_encode(text_encoder: 'CLIPTextModel', tokens):
     return text_encoder(tokens.to(text_encoder.device))[0]
 
 
+
 def encode_prompts(
         tokenizer: 'CLIPTokenizer',
         text_encoder: 'CLIPTokenizer',
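
For readers skimming the patch, the prompts2 handling added to encode_prompts_xl reduces to a small selection rule: text encoder 1 always receives the primary prompts, text encoder 2 receives prompts2 (falling back to prompts when none is supplied), and a disabled encoder is fed blank strings as before. Below is a minimal standalone sketch of that rule for illustration only; the function name and argument layout are invented here and are not part of the repo's API.

from typing import List, Optional


def select_prompts_for_encoder(
        idx: int,
        prompts: List[str],
        prompts2: Optional[List[str]] = None,
        use_text_encoder_1: bool = True,
        use_text_encoder_2: bool = True,
) -> List[str]:
    # mirror the patch: prompts2 falls back to the primary prompts when not given
    if prompts2 is None:
        prompts2 = prompts
    # encoder 0 gets prompts, encoder 1 gets prompts2
    prompt_list_to_use = prompts if idx == 0 else prompts2
    # a disabled encoder is fed blank strings, matching the existing behavior
    if idx == 0 and not use_text_encoder_1:
        prompt_list_to_use = ["" for _ in prompts]
    if idx == 1 and not use_text_encoder_2:
        prompt_list_to_use = ["" for _ in prompts]
    return prompt_list_to_use


if __name__ == "__main__":
    prompts = ["a photo of a cat"]
    # no prompts2 given: both encoders see the same text
    assert select_prompts_for_encoder(1, prompts) == prompts
    # prompts2 given: the second encoder sees its own text
    assert select_prompts_for_encoder(1, prompts, ["studio lighting"]) == ["studio lighting"]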