from typing import List, Optional, Union

import torch
from diffusers import FluxPipeline
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers, logging
from transformers import AutoModel, AutoTokenizer

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Flex2Pipeline(FluxPipeline):
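    """FluxPipeline variant that pairs the CLIP pooled text encoder with an
    LLM text encoder (`text_encoder_2`). A fixed system prompt is measured
    once in `__init__` so that its token positions can be stripped from the
    LLM hidden states at encode time, leaving only the user prompt's
    embeddings.
    """
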
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "

        # Measure the token length of the system prompt once so its positions
        # can be stripped from the encoder output in `_get_llm_prompt_embeds`.
        self.system_prompt_length = self.tokenizer_2(
            [self.system_prompt],
            padding="longest",
            return_tensors="pt",
        ).input_ids[0].shape[0]

    def _get_llm_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
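        r"""
        Encode `prompt` with the LLM text encoder (`text_encoder_2`) and return
        per-token embeddings with the system-prompt positions stripped.

        Note: `prompt` is assumed to already carry `self.system_prompt` as a
        prefix; the prepending is expected to happen in the caller, as it is
        not done in this method.
        """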
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length + self.system_prompt_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )

        # Keep the ids on CPU until the encoder call so the truncation check
        # below compares tensors on the same device.
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask.to(device)
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            # Slice from the truncation point (the max_length used above), not
            # from `self.tokenizer_max_length`, which belongs to the CLIP tokenizer.
            removed_text = self.tokenizer_2.batch_decode(
                untruncated_ids[:, max_sequence_length + self.system_prompt_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length + self.system_prompt_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(
            text_input_ids.to(device),
            attention_mask=prompt_attention_mask,
            output_hidden_states=True,
        )
        prompt_embeds = prompt_embeds.hidden_states[-1]

        # Remove the system prompt positions from the embeddings and attention mask.
        prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
        prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]

        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # Duplicate text embeddings for each generation per prompt, using an mps-friendly method.
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
r"""
|
|
|
|
Args:
|
|
prompt (`str` or `List[str]`, *optional*):
|
|
prompt to be encoded
|
|
prompt_2 (`str` or `List[str]`, *optional*):
|
|
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
|
used in all text-encoders
|
|
device: (`torch.device`):
|
|
torch device
|
|
num_images_per_prompt (`int`):
|
|
number of images that should be generated per prompt
|
|
prompt_embeds (`torch.FloatTensor`, *optional*):
|
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
|
provided, text embeddings will be generated from `prompt` input argument.
|
|
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
|
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
|
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
|
lora_scale (`float`, *optional*):
|
|
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
|
"""
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_llm_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids
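

# Minimal usage sketch (assumptions: "path/to/flex2-checkpoint" is a placeholder
# for a Flux-layout repo whose `tokenizer_2`/`text_encoder_2` hold the LLM pair;
# per the note in `_get_llm_prompt_embeds`, the caller prepends
# `pipe.system_prompt` to the prompt routed to the LLM branch):
#
#   pipe = Flex2Pipeline.from_pretrained("path/to/flex2-checkpoint", torch_dtype=torch.bfloat16)
#   pipe.to("cuda")
#   prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
#       prompt="a photo of a cat",
#       prompt_2=pipe.system_prompt + "a photo of a cat",
#   )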