Added llm text encoder adapter

Jaret Burkett
2025-02-13 08:28:32 -07:00
parent 2622de1e01
commit 7679105d52
4 changed files with 170 additions and 6 deletions


@@ -44,7 +44,7 @@ from transformers import (
    ConvNextModel,
    ConvNextForImageClassification,
    ConvNextImageProcessor,
-    UMT5EncoderModel, LlamaTokenizerFast
+    UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer
)
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
@@ -52,6 +52,8 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
from transformers import ViTFeatureExtractor, ViTForImageClassification
+from toolkit.models.llm_adapter import LLMAdapter
import torch.nn.functional as F
@@ -198,6 +200,20 @@ class CustomAdapter(torch.nn.Module):
                raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
            self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
+        elif self.adapter_type == 'llm_adapter':
+            self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
+                self.sd_ref().unet.device,
+                dtype=torch_dtype
+            )
+            self.te.eval()
+            self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
+            self.llm_adapter = LLMAdapter(
+                adapter=self,
+                sd=self.sd_ref(),
+                llm=self.te,
+                tokenizer=self.tokenizer,
+            )
+            self.llm_adapter.to(self.device, torch_dtype)
        elif self.adapter_type == 'te_augmenter':
            self.te_augmenter = TEAugAdapter(self, self.sd_ref())
        elif self.adapter_type == 'vision_direct':
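
Note: the LLMAdapter class itself is added by this commit in toolkit/models/llm_adapter.py and is not shown in this view. Below is a minimal sketch of the interface the hunks here rely on; only the constructor kwargs and the encode_text() return shape (.text_embeds, .attention_mask, .detach(), .clone()) come from the diff, while the weakref pattern, projection layer, and max_length are assumptions.

import weakref
from dataclasses import dataclass

import torch


@dataclass
class LLMPromptEmbeds:
    # mirrors how CustomAdapter uses encode_text()'s return value:
    # .text_embeds, .attention_mask, .detach(), .clone()
    text_embeds: torch.Tensor
    attention_mask: torch.Tensor

    def detach(self):
        return LLMPromptEmbeds(self.text_embeds.detach(), self.attention_mask.detach())

    def clone(self):
        return LLMPromptEmbeds(self.text_embeds.clone(), self.attention_mask.clone())


class LLMAdapter(torch.nn.Module):
    def __init__(self, adapter, sd, llm, tokenizer, out_dim=4096, max_length=128):
        super().__init__()
        # weakrefs keep the frozen LLM and the parent adapter out of this
        # module's parameters()/state_dict(); an assumption, but it matches
        # the sd_ref() pattern used elsewhere in this file
        self.adapter_ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.llm_ref = weakref.ref(llm)
        self.tokenizer = tokenizer
        self.max_length = max_length
        # hypothetical trainable projection from the LLM hidden width to
        # whatever width the diffusion model's text path expects
        self.proj = torch.nn.Linear(llm.config.hidden_size, out_dim)

    def encode_text(self, prompt) -> LLMPromptEmbeds:
        llm = self.llm_ref()
        # decoder-only tokenizers (e.g. Llama) may need a pad token set
        tokens = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).to(llm.device)
        hidden = llm(**tokens).last_hidden_state
        return LLMPromptEmbeds(self.proj(hidden), tokens.attention_mask)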
@@ -238,7 +254,7 @@ class CustomAdapter(torch.nn.Module):
    def setup_clip(self):
        adapter_config = self.config
        sd = self.sd_ref()
-        if self.config.type == "text_encoder" or self.config.type == "single_value":
+        if self.config.type in ["text_encoder", "llm_adapter", "single_value"]:
            return
        if self.config.type == 'photo_maker':
            try:
@@ -461,6 +477,9 @@ class CustomAdapter(torch.nn.Module):
        elif self.adapter_type == 'text_encoder':
            state_dict["te_adapter"] = self.te_adapter.state_dict()
            return state_dict
+        elif self.adapter_type == 'llm_adapter':
+            state_dict["llm_adapter"] = self.llm_adapter.state_dict()
+            return state_dict
        elif self.adapter_type == 'te_augmenter':
            if self.config.train_image_encoder:
                state_dict["vision_encoder"] = self.vision_encoder.state_dict()
@@ -510,6 +529,14 @@ class CustomAdapter(torch.nn.Module):
self.unconditional_embeds = self.te_adapter.encode_text(prompt).detach()
else:
self.conditional_embeds = self.te_adapter.encode_text(prompt).detach()
elif self.adapter_type == 'llm_adapter':
# todo allow for training
with torch.no_grad():
# encode and save the embeds
if is_unconditional:
self.unconditional_embeds = self.llm_adapter.encode_text(prompt).detach()
else:
self.conditional_embeds = self.llm_adapter.encode_text(prompt).detach()
return prompt
elif self.adapter_type == 'photo_maker':
if is_unconditional:
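
The net effect of this branch is that the frozen LLM runs once per prompt and the detached result is reused at every denoising step. A hedged usage sketch of the same pattern (the prompt strings are illustrative, not from the diff):

import torch

# cache conditional and unconditional embeds once, e.g. for
# classifier-free guidance; later steps only clone the cached tensors
with torch.no_grad():
    conditional_embeds = llm_adapter.encode_text("a photo of a cat").detach()
    unconditional_embeds = llm_adapter.encode_text("").detach()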
@@ -613,11 +640,20 @@ class CustomAdapter(torch.nn.Module):
        quad_count=4,
        is_generating_samples=False,
    ) -> PromptEmbeds:
-        if self.adapter_type == 'text_encoder' and is_generating_samples:
+        if self.adapter_type == 'text_encoder':
            # replace the prompt embed with ours
            if is_unconditional:
                return self.unconditional_embeds.clone()
            return self.conditional_embeds.clone()
+        if self.adapter_type == 'llm_adapter':
+            # replace the prompt embed with ours
+            if is_unconditional:
+                prompt_embeds.text_embeds = self.unconditional_embeds.text_embeds.clone()
+                prompt_embeds.attention_mask = self.unconditional_embeds.attention_mask.clone()
+                return prompt_embeds
+            prompt_embeds.text_embeds = self.conditional_embeds.text_embeds.clone()
+            prompt_embeds.attention_mask = self.conditional_embeds.attention_mask.clone()
+            return prompt_embeds
        if self.adapter_type == 'ilora':
            return prompt_embeds
@@ -977,6 +1013,8 @@ class CustomAdapter(torch.nn.Module):
        elif self.config.type == 'text_encoder':
            for attn_processor in self.te_adapter.adapter_modules:
                yield from attn_processor.parameters(recurse)
+        elif self.config.type == 'llm_adapter':
+            yield from self.llm_adapter.parameters(recurse)
        elif self.config.type == 'vision_direct':
            if self.config.train_scaler:
                # only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)
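
Because this generator yields only the LLMAdapter's modules for the llm_adapter type, an optimizer built from it trains the adapter while the .eval() text encoder stays frozen, assuming the LLMAdapter does not register the LLM itself as a submodule. A hedged usage sketch (the instance name is hypothetical):

import torch

# only the adapter weights reach the optimizer; the frozen text
# encoder loaded with .eval() above contributes no parameters here
optimizer = torch.optim.AdamW(custom_adapter.parameters(), lr=1e-4)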