mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
WIP Flex 2 pipeline
This commit is contained in:
160
toolkit/models/flex2.py
Normal file
160
toolkit/models/flex2.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
from typing import List, Optional, Union
|
||||||
|
from diffusers import FluxPipeline
|
||||||
|
import torch
|
||||||
|
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||||
|
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers, logging
|
||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
class Flex2Pipeline(FluxPipeline):
    """Flux-based pipeline variant that replaces the T5 text path with an
    LLM-style encoder (``text_encoder_2`` called with an attention mask and
    ``output_hidden_states=True``), driven by a fixed system prompt.

    The system prompt is prepended to every user prompt before tokenization,
    and its token span is stripped back out of the resulting embeddings so
    downstream consumers only see embeddings for the user's own text.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fixed instruction prepended to every prompt sent to the LLM encoder.
        # NOTE: this is a runtime string consumed by the tokenizer — do not
        # reformat it.
        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. <Prompt Start> "

        # Determine the system prompt's length in tokens once, so the same
        # count can be stripped from embeddings / attention masks later.
        self.system_prompt_length = self.tokenizer_2(
            [self.system_prompt],
            padding="longest",
            return_tensors="pt",
        ).input_ids[0].shape[0]

    def _get_llm_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        """Encode ``prompt`` with the LLM text encoder (``text_encoder_2``).

        The system prompt is prepended before tokenization and its token span
        is removed from the returned embeddings, so the output covers only the
        user text (up to ``max_sequence_length`` tokens).

        Args:
            prompt: A prompt string or list of prompt strings.
            num_images_per_prompt: How many images will be generated per
                prompt; embeddings are duplicated accordingly.
            max_sequence_length: Maximum number of user-prompt tokens kept
                (the system prompt's tokens are budgeted on top of this).
            device: Target device; defaults to the pipeline's execution device.
            dtype: Requested dtype; note the final cast below uses
                ``text_encoder_2.dtype`` regardless.

        Returns:
            Tensor of shape ``(batch * num_images_per_prompt, seq_len, dim)``.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

        # FIX: prepend the system prompt before tokenizing. The original code
        # tokenized the bare user prompt but still sliced off
        # `system_prompt_length` tokens further down, which discarded the
        # *start of the user's prompt* instead of the system prompt.
        prompt = [self.system_prompt + p for p in prompt]

        text_inputs = self.tokenizer_2(
            prompt,
            padding="max_length",
            # Budget the system prompt's tokens on top of the user budget so
            # truncation never eats into the user's max_sequence_length.
            max_length=max_sequence_length + self.system_prompt_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.to(device)
        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            # NOTE(review): `tokenizer_max_length` is the CLIP limit inherited
            # from FluxPipeline; this slice mirrors the upstream warning code
            # but may not line up with this tokenizer's budget — confirm.
            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length + self.system_prompt_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder_2(
            text_input_ids,
            attention_mask=prompt_attention_mask,
            output_hidden_states=True
        )
        # Use the last hidden state layer as the prompt representation.
        prompt_embeds = prompt_embeds.hidden_states[-1]

        # remove the system prompt from the input and attention mask
        prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
        prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]

        # Cast to the encoder's dtype (overrides any `dtype` argument).
        dtype = self.text_encoder_2.dtype
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per
        # prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encode prompts into (LLM embeddings, pooled CLIP embeddings, text ids).

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = self._get_clip_prompt_embeds(
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = self._get_llm_prompt_embeds(
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if self.text_encoder is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
        # Flux-style positional ids for the text tokens: all zeros, one row per
        # embedding position.
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||||
@@ -151,7 +151,7 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
prompt_embeds = text_encoder(
|
prompt_embeds = text_encoder(
|
||||||
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||||
)
|
)
|
||||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
prompt_embeds = prompt_embeds.hidden_states[-1]
|
||||||
|
|
||||||
prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
|
prompt_embeds = prompt_embeds[:, self.system_prompt_length:]
|
||||||
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
|
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
|
||||||
|
|||||||
Reference in New Issue
Block a user