diff --git a/toolkit/models/flex2.py b/toolkit/models/flex2.py
new file mode 100644
index 00000000..d8eae2e6
--- /dev/null
+++ b/toolkit/models/flex2.py
@@ -0,0 +1,160 @@
+from typing import List, Optional, Union
+from diffusers import FluxPipeline
+import torch
+from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers, logging
+from transformers import AutoModel, AutoTokenizer
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+class Flex2Pipeline(FluxPipeline):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. "
+
+        # determine length of system prompt
+        self.system_prompt_length = self.tokenizer_2(
+            [self.system_prompt],
+            padding="longest",
+            return_tensors="pt",
+        ).input_ids[0].shape[0]
+
+
+    def _get_llm_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 512,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+        text_inputs = self.tokenizer_2(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length + self.system_prompt_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids.to(device)
+        prompt_attention_mask = text_inputs.attention_mask.to(device)
+        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length + self.system_prompt_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length + self.system_prompt_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder_2(
+            text_input_ids,
+            attention_mask=prompt_attention_mask,
+            output_hidden_states=True
+        )
+        prompt_embeds = prompt_embeds.hidden_states[-1]
+
+        # remove the system prompt from the input and attention mask
+        prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
+        prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
+
+        dtype = self.text_encoder_2.dtype
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        _, seq_len, _ = prompt_embeds.shape
+
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        max_sequence_length: int = 512,
+        lora_scale: Optional[float] = None,
+    ):
+        r"""
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in all text-encoders
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder_2, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # We only use the pooled prompt output from the CLIPTextModel
+            pooled_prompt_embeds = self._get_clip_prompt_embeds(
+                prompt=prompt,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+            )
+            prompt_embeds = self._get_llm_prompt_embeds(
+                prompt=prompt_2,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
+
+        if self.text_encoder is not None:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        if self.text_encoder_2 is not None:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+        return prompt_embeds, pooled_prompt_embeds, text_ids
\ No newline at end of file
diff --git a/toolkit/models/llm_adapter.py b/toolkit/models/llm_adapter.py
index c35f6ae7..d3fd612d 100644
--- a/toolkit/models/llm_adapter.py
+++ b/toolkit/models/llm_adapter.py
@@ -151,7 +151,7 @@ class LLMAdapter(torch.nn.Module):
         prompt_embeds = text_encoder(
             text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
         )
-        prompt_embeds = prompt_embeds.hidden_states[-2]
+        prompt_embeds = prompt_embeds.hidden_states[-1]
 
         prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
         prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
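
Below is a minimal usage sketch, not part of the diff, showing how the new `encode_prompt` path might be driven. It assumes a `Flex2Pipeline` checkpoint whose `tokenizer_2`/`text_encoder_2` are the intended LLM rather than T5; the checkpoint path, device string, and prompt text are placeholders.

import torch
from toolkit.models.flex2 import Flex2Pipeline

# hypothetical checkpoint path; any saved pipeline with compatible components would do
pipe = Flex2Pipeline.from_pretrained("path/to/flex2-checkpoint", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# pooled embeds come from the CLIP text encoder (prompt); sequence embeds come from the
# LLM encoder (prompt_2, falling back to prompt when None)
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="a photo of a cat",
    prompt_2=None,
    num_images_per_prompt=1,
    max_sequence_length=512,
)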