Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 11:11:37 +00:00)
Commit: Added llm text encoder adapter
@@ -1540,7 +1540,7 @@ class SDTrainer(BaseSDTrainProcess):
                 prior_pred = prior_pred.detach()
 
             # do the custom adapter after the prior prediction
-            if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
+            if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or self.adapter_config.type in ['llm_adapter', 'text_encoder']):
                 quad_count = random.randint(1, 4)
                 self.adapter.train()
                 conditional_embeds = self.adapter.condition_encoded_embeds(
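Note: the widened gate above is the only training-loop change the new adapter type needs. Adapter types that condition on text alone now enter this branch even when the batch carries no CLIP image. Below is a minimal sketch of the predicate as a hypothetical standalone helper; the real code inlines the expression and also checks isinstance(adapter, CustomAdapter).

# --- editor's sketch, not part of the commit ---
TEXT_ONLY_ADAPTER_TYPES = ['llm_adapter', 'text_encoder']

def should_run_custom_adapter(adapter, adapter_config, has_clip_image: bool) -> bool:
    # Text-only adapter types no longer require a CLIP image in the batch.
    return adapter is not None and (
        has_clip_image or adapter_config.type in TEXT_ONLY_ADAPTER_TYPES
    )
# --- end sketch ---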
@@ -44,7 +44,7 @@ from transformers import (
     ConvNextModel,
     ConvNextForImageClassification,
     ConvNextImageProcessor,
-    UMT5EncoderModel, LlamaTokenizerFast
+    UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer
 )
 from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
 
@@ -52,6 +52,8 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
 
 from transformers import ViTFeatureExtractor, ViTForImageClassification
 
+from toolkit.models.llm_adapter import LLMAdapter
+
 import torch.nn.functional as F
 
 
@@ -198,6 +200,20 @@ class CustomAdapter(torch.nn.Module):
                 raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
 
             self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
+        elif self.adapter_type == 'llm_adapter':
+            self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
+                self.sd_ref().unet.device,
+                dtype=torch_dtype
+            )
+            self.te.eval()
+            self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
+            self.llm_adapter = LLMAdapter(
+                adapter=self,
+                sd=self.sd_ref(),
+                llm=self.te,
+                tokenizer=self.tokenizer,
+            )
+            self.llm_adapter.to(self.device, torch_dtype)
         elif self.adapter_type == 'te_augmenter':
             self.te_augmenter = TEAugAdapter(self, self.sd_ref())
         elif self.adapter_type == 'vision_direct':
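Note: the new branch loads the text encoder with the generic Auto* classes, freezes it with eval(), and wraps it in LLMAdapter together with its tokenizer. A standalone sketch of just the loading step follows; the model id, dtype, and device handling are illustrative assumptions, the commit only requires that config.text_encoder_path be loadable by AutoModel/AutoTokenizer.

# --- editor's sketch, not part of the commit ---
import torch
from transformers import AutoModel, AutoTokenizer

text_encoder_path = "Qwen/Qwen2-0.5B"  # illustrative; any AutoModel-loadable LLM
device = "cuda" if torch.cuda.is_available() else "cpu"

te = AutoModel.from_pretrained(text_encoder_path).to(device, dtype=torch.bfloat16)
te.eval()  # the LLM stays frozen; only the adapter's projection layer trains
tokenizer = AutoTokenizer.from_pretrained(text_encoder_path)
# --- end sketch ---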
@@ -238,7 +254,7 @@ class CustomAdapter(torch.nn.Module):
     def setup_clip(self):
         adapter_config = self.config
         sd = self.sd_ref()
-        if self.config.type == "text_encoder" or self.config.type == "single_value":
+        if self.config.type in ["text_encoder", "llm_adapter", "single_value"]:
             return
         if self.config.type == 'photo_maker':
             try:
@@ -461,6 +477,9 @@ class CustomAdapter(torch.nn.Module):
         elif self.adapter_type == 'text_encoder':
             state_dict["te_adapter"] = self.te_adapter.state_dict()
             return state_dict
+        elif self.adapter_type == 'llm_adapter':
+            state_dict["llm_adapter"] = self.llm_adapter.state_dict()
+            return state_dict
         elif self.adapter_type == 'te_augmenter':
             if self.config.train_image_encoder:
                 state_dict["vision_encoder"] = self.vision_encoder.state_dict()
@@ -510,6 +529,14 @@ class CustomAdapter(torch.nn.Module):
                     self.unconditional_embeds = self.te_adapter.encode_text(prompt).detach()
                 else:
                     self.conditional_embeds = self.te_adapter.encode_text(prompt).detach()
+        elif self.adapter_type == 'llm_adapter':
+            # todo allow for training
+            with torch.no_grad():
+                # encode and save the embeds
+                if is_unconditional:
+                    self.unconditional_embeds = self.llm_adapter.encode_text(prompt).detach()
+                else:
+                    self.conditional_embeds = self.llm_adapter.encode_text(prompt).detach()
             return prompt
         elif self.adapter_type == 'photo_maker':
             if is_unconditional:
@@ -613,11 +640,20 @@ class CustomAdapter(torch.nn.Module):
             quad_count=4,
             is_generating_samples=False,
     ) -> PromptEmbeds:
-        if self.adapter_type == 'text_encoder' and is_generating_samples:
+        if self.adapter_type == 'text_encoder':
             # replace the prompt embed with ours
             if is_unconditional:
                 return self.unconditional_embeds.clone()
             return self.conditional_embeds.clone()
+        if self.adapter_type == 'llm_adapter':
+            # replace the prompt embed with ours
+            if is_unconditional:
+                prompt_embeds.text_embeds = self.unconditional_embeds.text_embeds.clone()
+                prompt_embeds.attention_mask = self.unconditional_embeds.attention_mask.clone()
+                return prompt_embeds
+            prompt_embeds.text_embeds = self.conditional_embeds.text_embeds.clone()
+            prompt_embeds.attention_mask = self.conditional_embeds.attention_mask.clone()
+            return prompt_embeds
 
         if self.adapter_type == 'ilora':
             return prompt_embeds
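Note: unlike the text_encoder branch, which returns its cached PromptEmbeds outright, the llm_adapter branch copies the cached text_embeds and attention_mask into the PromptEmbeds it was handed and returns that same object. A toy sketch of the swap, using a stand-in class rather than toolkit.prompt_utils.PromptEmbeds:

# --- editor's sketch, not part of the commit ---
from dataclasses import dataclass
import torch

@dataclass
class FakePromptEmbeds:  # stand-in for toolkit.prompt_utils.PromptEmbeds
    text_embeds: torch.Tensor
    attention_mask: torch.Tensor

def swap_in_cached(prompt_embeds: FakePromptEmbeds, cached: FakePromptEmbeds) -> FakePromptEmbeds:
    # Overwrite in place and hand back the same container, mirroring the diff above.
    prompt_embeds.text_embeds = cached.text_embeds.clone()
    prompt_embeds.attention_mask = cached.attention_mask.clone()
    return prompt_embeds
# --- end sketch ---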
@@ -977,6 +1013,8 @@ class CustomAdapter(torch.nn.Module):
         elif self.config.type == 'text_encoder':
             for attn_processor in self.te_adapter.adapter_modules:
                 yield from attn_processor.parameters(recurse)
+        elif self.config.type == 'llm_adapter':
+            yield from self.llm_adapter.parameters(recurse)
         elif self.config.type == 'vision_direct':
             if self.config.train_scaler:
                 # only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)
new file: toolkit/models/llm_adapter.py (127 lines)
@@ -0,0 +1,127 @@
from functools import partial
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import weakref
from typing import List, Optional, Tuple, Union, TYPE_CHECKING


from transformers import AutoModel, AutoTokenizer, Qwen2Model, LlamaModel, Qwen2Tokenizer, LlamaTokenizer

from toolkit import train_tools
from toolkit.prompt_utils import PromptEmbeds
from diffusers import Transformer2DModel


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
    from toolkit.custom_adapter import CustomAdapter

LLM = Union[Qwen2Model, LlamaModel]
LLMTokenizer = Union[Qwen2Tokenizer, LlamaTokenizer]


def new_context_embedder_forward(self, x):
    if self._adapter_ref().is_active:
        x = self._context_embedder_ref()(x)
    else:
        x = self._orig_forward(x)
    return x


class LLMAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'StableDiffusion',
        llm: LLM,
        tokenizer: LLMTokenizer,
    ):
        super(LLMAdapter, self).__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.llm_ref: weakref.ref = weakref.ref(llm)
        self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)

        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts."

        self.hidden_size = llm.config.hidden_size

        if sd.is_flux:
            self.context_embedder = nn.Linear(
                self.hidden_size, sd.unet.inner_dim)
            self.sequence_length = 512
            sd.unet.context_embedder._orig_forward = sd.unet.context_embedder.forward
            sd.unet.context_embedder.forward = partial(
                new_context_embedder_forward, sd.unet.context_embedder)
            sd.unet.context_embedder._context_embedder_ref = weakref.ref(self.context_embedder)
            # add a is active property to the context embedder
            sd.unet.context_embedder._adapter_ref = self.adapter_ref

        elif sd.is_lumina2:
            self.context_embedder = nn.Linear(
                self.hidden_size, sd.unet.hidden_size)
            self.sequence_length = 256
        else:
            raise ValueError(
                "llm adapter currently only supports flux or lumina2")

    def _get_prompt_embeds(
        self,
        prompt: Union[str, List[str]],
        max_sequence_length: int = 256,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        tokenizer = self.tokenizer_ref()
        text_encoder = self.llm_ref()
        device = text_encoder.device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)

        prompt_attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = text_encoder(
            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        )
        prompt_embeds = prompt_embeds.hidden_states[-2]

        dtype = text_encoder.dtype

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return prompt_embeds, prompt_attention_mask

    # make a getter to see if is active

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def encode_text(self, prompt):

        prompt = prompt if isinstance(prompt, list) else [prompt]

        prompt = [self.system_prompt + " <Prompt Start> " + p for p in prompt]

        prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
            prompt=prompt,
            max_sequence_length=self.sequence_length,
        )

        prompt_embeds = PromptEmbeds(
            prompt_embeds,
            attention_mask=prompt_attention_mask,
        ).detach()

        return prompt_embeds

    def forward(self, input):
        return input
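Note on the Flux path: LLMAdapter leaves the transformer itself untouched and instead swaps the context_embedder's forward, so that while the adapter is active the text tokens are projected from the LLM hidden size into the transformer's inner dimension by the adapter's own nn.Linear, falling back to the original projection otherwise. Below is a self-contained sketch of that forward-swap pattern on a toy module; the names and dimensions are illustrative assumptions, only the pattern mirrors the commit.

# --- editor's sketch, not part of the commit ---
import weakref
from functools import partial

import torch
import torch.nn as nn

def patched_forward(self, x):
    # Route through the adapter's projection only while the adapter is active.
    if self._adapter_ref().is_active:
        return self._context_embedder_ref()(x)
    return self._orig_forward(x)

class ToyAdapter(nn.Module):
    def __init__(self, llm_hidden: int, inner_dim: int, context_embedder: nn.Module):
        super().__init__()
        self.is_active = True
        self.context_embedder = nn.Linear(llm_hidden, inner_dim)
        # Monkey-patch the base model's embedder, keeping only weak back-references
        # so the adapter can be dropped independently of the base model.
        context_embedder._orig_forward = context_embedder.forward
        context_embedder.forward = partial(patched_forward, context_embedder)
        context_embedder._context_embedder_ref = weakref.ref(self.context_embedder)
        context_embedder._adapter_ref = weakref.ref(self)

# Usage: a stand-in for sd.unet.context_embedder with Flux-like dimensions.
base_embedder = nn.Linear(4096, 3072)   # original projection in the base model
adapter = ToyAdapter(llm_hidden=896, inner_dim=3072, context_embedder=base_embedder)
tokens = torch.randn(1, 512, 896)       # LLM hidden states instead of T5 states
out = base_embedder(tokens)             # shape (1, 512, 3072) via the adapter's Linear
# --- end sketch ---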
@@ -1356,8 +1356,7 @@ class StableDiffusion:
             conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
             unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
 
-        if self.adapter is not None and isinstance(self.adapter,
-                                                    CustomAdapter) and validation_image is not None:
+        if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
             conditional_embeds = self.adapter.condition_encoded_embeds(
                 tensors_0_1=validation_image,
                 prompt_embeds=conditional_embeds,