Added llm text encoder adapter

This commit is contained in:
Jaret Burkett
2025-02-13 08:28:32 -07:00
parent 2622de1e01
commit 7679105d52
4 changed files with 170 additions and 6 deletions

View File

@@ -44,7 +44,7 @@ from transformers import (
ConvNextModel,
ConvNextForImageClassification,
ConvNextImageProcessor,
UMT5EncoderModel, LlamaTokenizerFast
UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer
)
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
@@ -52,6 +52,8 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassificatio
from transformers import ViTFeatureExtractor, ViTForImageClassification
from toolkit.models.llm_adapter import LLMAdapter
import torch.nn.functional as F
@@ -198,6 +200,20 @@ class CustomAdapter(torch.nn.Module):
raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
elif self.adapter_type == 'llm_adapter':
self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
self.sd_ref().unet.device,
dtype=torch_dtype
)
self.te.eval()
self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
self.llm_adapter = LLMAdapter(
adapter=self,
sd=self.sd_ref(),
llm=self.te,
tokenizer=self.tokenizer,
)
self.llm_adapter.to(self.device, torch_dtype)
elif self.adapter_type == 'te_augmenter':
self.te_augmenter = TEAugAdapter(self, self.sd_ref())
elif self.adapter_type == 'vision_direct':
@@ -238,7 +254,7 @@ class CustomAdapter(torch.nn.Module):
def setup_clip(self):
adapter_config = self.config
sd = self.sd_ref()
if self.config.type == "text_encoder" or self.config.type == "single_value":
if self.config.type in ["text_encoder", "llm_adapter", "single_value"]:
return
if self.config.type == 'photo_maker':
try:
@@ -461,6 +477,9 @@ class CustomAdapter(torch.nn.Module):
elif self.adapter_type == 'text_encoder':
state_dict["te_adapter"] = self.te_adapter.state_dict()
return state_dict
elif self.adapter_type == 'llm_adapter':
state_dict["llm_adapter"] = self.llm_adapter.state_dict()
return state_dict
elif self.adapter_type == 'te_augmenter':
if self.config.train_image_encoder:
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
@@ -510,6 +529,14 @@ class CustomAdapter(torch.nn.Module):
self.unconditional_embeds = self.te_adapter.encode_text(prompt).detach()
else:
self.conditional_embeds = self.te_adapter.encode_text(prompt).detach()
elif self.adapter_type == 'llm_adapter':
# todo allow for training
with torch.no_grad():
# encode and save the embeds
if is_unconditional:
self.unconditional_embeds = self.llm_adapter.encode_text(prompt).detach()
else:
self.conditional_embeds = self.llm_adapter.encode_text(prompt).detach()
return prompt
elif self.adapter_type == 'photo_maker':
if is_unconditional:
@@ -613,11 +640,20 @@ class CustomAdapter(torch.nn.Module):
quad_count=4,
is_generating_samples=False,
) -> PromptEmbeds:
if self.adapter_type == 'text_encoder' and is_generating_samples:
if self.adapter_type == 'text_encoder':
# replace the prompt embed with ours
if is_unconditional:
return self.unconditional_embeds.clone()
return self.conditional_embeds.clone()
if self.adapter_type == 'llm_adapter':
# replace the prompt embed with ours
if is_unconditional:
prompt_embeds.text_embeds = self.unconditional_embeds.text_embeds.clone()
prompt_embeds.attention_mask = self.unconditional_embeds.attention_mask.clone()
return prompt_embeds
prompt_embeds.text_embeds = self.conditional_embeds.text_embeds.clone()
prompt_embeds.attention_mask = self.conditional_embeds.attention_mask.clone()
return prompt_embeds
if self.adapter_type == 'ilora':
return prompt_embeds
@@ -977,6 +1013,8 @@ class CustomAdapter(torch.nn.Module):
elif self.config.type == 'text_encoder':
for attn_processor in self.te_adapter.adapter_modules:
yield from attn_processor.parameters(recurse)
elif self.config.type == 'llm_adapter':
yield from self.llm_adapter.parameters(recurse)
elif self.config.type == 'vision_direct':
if self.config.train_scaler:
# only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)

View File

@@ -0,0 +1,127 @@
from functools import partial
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import weakref
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
from transformers import AutoModel, AutoTokenizer, Qwen2Model, LlamaModel, Qwen2Tokenizer, LlamaTokenizer
from toolkit import train_tools
from toolkit.prompt_utils import PromptEmbeds
from diffusers import Transformer2DModel
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
from toolkit.custom_adapter import CustomAdapter
LLM = Union[Qwen2Model, LlamaModel]
LLMTokenizer = Union[Qwen2Tokenizer, LlamaTokenizer]
def new_context_embedder_forward(self, x):
if self._adapter_ref().is_active:
x = self._context_embedder_ref()(x)
else:
x = self._orig_forward(x)
return x
class LLMAdapter(torch.nn.Module):
def __init__(
self,
adapter: 'CustomAdapter',
sd: 'StableDiffusion',
llm: LLM,
tokenizer: LLMTokenizer,
):
super(LLMAdapter, self).__init__()
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.sd_ref: weakref.ref = weakref.ref(sd)
self.llm_ref: weakref.ref = weakref.ref(llm)
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts."
self.hidden_size = llm.config.hidden_size
if sd.is_flux:
self.context_embedder = nn.Linear(
self.hidden_size, sd.unet.inner_dim)
self.sequence_length = 512
sd.unet.context_embedder._orig_forward = sd.unet.context_embedder.forward
sd.unet.context_embedder.forward = partial(
new_context_embedder_forward, sd.unet.context_embedder)
sd.unet.context_embedder._context_embedder_ref = weakref.ref(self.context_embedder)
# add a is active property to the context embedder
sd.unet.context_embedder._adapter_ref = self.adapter_ref
elif sd.is_lumina2:
self.context_embedder = nn.Linear(
self.hidden_size, sd.unet.hidden_size)
self.sequence_length = 256
else:
raise ValueError(
"llm adapter currently only supports flux or lumina2")
def _get_prompt_embeds(
self,
prompt: Union[str, List[str]],
max_sequence_length: int = 256,
) -> Tuple[torch.Tensor, torch.Tensor]:
tokenizer = self.tokenizer_ref()
text_encoder = self.llm_ref()
device = text_encoder.device
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = text_encoder(
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
)
prompt_embeds = prompt_embeds.hidden_states[-2]
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
# make a getter to see if is active
@property
def is_active(self):
return self.adapter_ref().is_active
def encode_text(self, prompt):
prompt = prompt if isinstance(prompt, list) else [prompt]
prompt = [self.system_prompt + " <Prompt Start> " + p for p in prompt]
prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
prompt=prompt,
max_sequence_length=self.sequence_length,
)
prompt_embeds = PromptEmbeds(
prompt_embeds,
attention_mask=prompt_attention_mask,
).detach()
return prompt_embeds
def forward(self, input):
return input

View File

@@ -1356,8 +1356,7 @@ class StableDiffusion:
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
if self.adapter is not None and isinstance(self.adapter,
CustomAdapter) and validation_image is not None:
if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
conditional_embeds = self.adapter.condition_encoded_embeds(
tensors_0_1=validation_image,
prompt_embeds=conditional_embeds,