from typing import List, Optional, Union from diffusers import FluxPipeline import torch from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers, logging from transformers import AutoModel, AutoTokenizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name class Flex2Pipeline(FluxPipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. " # determine length of system prompt self.system_prompt_length = self.tokenizer_2( [self.system_prompt], padding="longest", return_tensors="pt", ).input_ids[0].shape[0] def _get_llm_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) text_inputs = self.tokenizer_2( prompt, padding="max_length", max_length=max_sequence_length + self.system_prompt_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(device) prompt_attention_mask = text_inputs.attention_mask.to(device) untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length + self.system_prompt_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_2( text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True ) prompt_embeds = prompt_embeds.hidden_states[-1] # remove the system prompt from the input and attention mask prompt_embeds = prompt_embeds[:, self.system_prompt_length:] prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], device: Optional[torch.device] = None, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None, ): r""" Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text-encoders device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # We only use the pooled prompt output from the CLIPTextModel pooled_prompt_embeds = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, ) prompt_embeds = self._get_llm_prompt_embeds( prompt=prompt_2, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) if self.text_encoder is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids