mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-28 18:21:16 +00:00)
Added llm text encoder adapter
@@ -44,7 +44,7 @@ from transformers import (
     ConvNextModel,
     ConvNextForImageClassification,
     ConvNextImageProcessor,
-    UMT5EncoderModel, LlamaTokenizerFast
+    UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer
 )
 from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
@@ -52,6 +52,8 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
 
 from transformers import ViTFeatureExtractor, ViTForImageClassification
 
+from toolkit.models.llm_adapter import LLMAdapter
+
 import torch.nn.functional as F
 
 
@@ -198,6 +200,20 @@ class CustomAdapter(torch.nn.Module):
                 raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
 
             self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
+        elif self.adapter_type == 'llm_adapter':
+            self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
+                self.sd_ref().unet.device,
+                dtype=torch_dtype
+            )
+            self.te.eval()
+            self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
+            self.llm_adapter = LLMAdapter(
+                adapter=self,
+                sd=self.sd_ref(),
+                llm=self.te,
+                tokenizer=self.tokenizer,
+            )
+            self.llm_adapter.to(self.device, torch_dtype)
         elif self.adapter_type == 'te_augmenter':
             self.te_augmenter = TEAugAdapter(self, self.sd_ref())
         elif self.adapter_type == 'vision_direct':
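
For orientation, a minimal sketch of the interface this hunk assumes from toolkit/models/llm_adapter.py. This is not the actual implementation: the constructor keywords mirror the call above, PromptEmbedsLike is a stand-in for toolkit's PromptEmbeds (later hunks show it carrying text_embeds and attention_mask and supporting detach()), the weakref pattern is inferred from the sd_ref() calls visible in the diff, and the projection layer is an assumption.

# Sketch only: the assumed shape of toolkit.models.llm_adapter.LLMAdapter,
# not the real implementation.
import weakref
from dataclasses import dataclass

import torch


@dataclass
class PromptEmbedsLike:
    # stand-in for toolkit's PromptEmbeds
    text_embeds: torch.Tensor
    attention_mask: torch.Tensor

    def detach(self):
        # mirrors the .detach() calls the diff makes on encode_text output
        return PromptEmbedsLike(self.text_embeds.detach(), self.attention_mask.detach())


class LLMAdapter(torch.nn.Module):
    def __init__(self, adapter, sd, llm, tokenizer):
        super().__init__()
        # weakrefs keep the frozen LLM and parent objects out of this module's
        # registered submodules, so parameters() yields only adapter weights
        self.adapter_ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.llm_ref = weakref.ref(llm)
        self.tokenizer = tokenizer
        # assumed: a trainable projection over the LLM hidden states; the real
        # module may instead project to the UNet's cross-attention width
        self.proj = torch.nn.Linear(llm.config.hidden_size, llm.config.hidden_size)

    def encode_text(self, prompt):
        # tokenize, run the frozen LLM, project the hidden states
        llm = self.llm_ref()
        tokens = self.tokenizer(prompt, padding=True, return_tensors="pt").to(llm.device)
        hidden = llm(**tokens).last_hidden_state
        return PromptEmbedsLike(self.proj(hidden), tokens["attention_mask"])

Because the AutoModel is loaded separately and put in eval(), only modules registered on LLMAdapter itself end up trainable; the parameters() hunk at the end of this diff makes that explicit.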
@@ -238,7 +254,7 @@ class CustomAdapter(torch.nn.Module):
     def setup_clip(self):
         adapter_config = self.config
         sd = self.sd_ref()
-        if self.config.type == "text_encoder" or self.config.type == "single_value":
+        if self.config.type in ["text_encoder", "llm_adapter", "single_value"]:
             return
         if self.config.type == 'photo_maker':
             try:
@@ -461,6 +477,9 @@ class CustomAdapter(torch.nn.Module):
         elif self.adapter_type == 'text_encoder':
             state_dict["te_adapter"] = self.te_adapter.state_dict()
             return state_dict
+        elif self.adapter_type == 'llm_adapter':
+            state_dict["llm_adapter"] = self.llm_adapter.state_dict()
+            return state_dict
         elif self.adapter_type == 'te_augmenter':
             if self.config.train_image_encoder:
                 state_dict["vision_encoder"] = self.vision_encoder.state_dict()
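
Given the branch above, a checkpoint saved for this adapter type carries only the llm_adapter weights, not the frozen LLM. A sketch, assuming `adapter` is the CustomAdapter and the file name is illustrative:

import torch

sd = adapter.state_dict()        # {'llm_adapter': {...}} per the branch above
torch.save(sd, "llm_adapter_checkpoint.pt")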
@@ -510,6 +529,14 @@ class CustomAdapter(torch.nn.Module):
                     self.unconditional_embeds = self.te_adapter.encode_text(prompt).detach()
                 else:
                     self.conditional_embeds = self.te_adapter.encode_text(prompt).detach()
+        elif self.adapter_type == 'llm_adapter':
+            # todo allow for training
+            with torch.no_grad():
+                # encode and save the embeds
+                if is_unconditional:
+                    self.unconditional_embeds = self.llm_adapter.encode_text(prompt).detach()
+                else:
+                    self.conditional_embeds = self.llm_adapter.encode_text(prompt).detach()
+            return prompt
         elif self.adapter_type == 'photo_maker':
             if is_unconditional:
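
Together with the constructor hunk, this sets up a two-phase contract: condition_prompt runs the frozen LLM once per prompt and caches the embeds (the `# todo allow for training` marks that the encoder output is not backpropped), and condition_encoded_embeds (patched next) later swaps them into the pipeline's embeds. A hedged usage sketch; argument lists are abridged to names visible in this diff, and the real methods may take more:

# adapter: a CustomAdapter configured with adapter_type 'llm_adapter'
adapter.condition_prompt("a photo of a dog", is_unconditional=False)
adapter.condition_prompt("", is_unconditional=True)

# later, once the pipeline has built its own PromptEmbeds:
cond = adapter.condition_encoded_embeds(
    prompt_embeds=base_embeds,      # PromptEmbeds from the base text encoder
    is_unconditional=False,
    quad_count=4,                   # defaults as shown in the next hunk
    is_generating_samples=False,
)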
@@ -613,11 +640,20 @@ class CustomAdapter(torch.nn.Module):
             quad_count=4,
             is_generating_samples=False,
     ) -> PromptEmbeds:
-        if self.adapter_type == 'text_encoder' and is_generating_samples:
+        if self.adapter_type == 'text_encoder':
             # replace the prompt embed with ours
             if is_unconditional:
                 return self.unconditional_embeds.clone()
             return self.conditional_embeds.clone()
+        if self.adapter_type == 'llm_adapter':
+            # replace the prompt embed with ours
+            if is_unconditional:
+                prompt_embeds.text_embeds = self.unconditional_embeds.text_embeds.clone()
+                prompt_embeds.attention_mask = self.unconditional_embeds.attention_mask.clone()
+                return prompt_embeds
+            prompt_embeds.text_embeds = self.conditional_embeds.text_embeds.clone()
+            prompt_embeds.attention_mask = self.conditional_embeds.attention_mask.clone()
+            return prompt_embeds
 
         if self.adapter_type == 'ilora':
             return prompt_embeds
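
One reading of the asymmetry above: the text_encoder path returns the cached PromptEmbeds wholesale, while the llm_adapter path copies only text_embeds and attention_mask into the incoming object, so any other state the pipeline attached to prompt_embeds survives. Distilled into an illustrative helper (not part of the codebase):

def inject_cached(prompt_embeds, cached):
    # overwrite just the cached fields; keep everything else on prompt_embeds
    prompt_embeds.text_embeds = cached.text_embeds.clone()
    prompt_embeds.attention_mask = cached.attention_mask.clone()
    return prompt_embeds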
@@ -977,6 +1013,8 @@ class CustomAdapter(torch.nn.Module):
         elif self.config.type == 'text_encoder':
             for attn_processor in self.te_adapter.adapter_modules:
                 yield from attn_processor.parameters(recurse)
+        elif self.config.type == 'llm_adapter':
+            yield from self.llm_adapter.parameters(recurse)
         elif self.config.type == 'vision_direct':
             if self.config.train_scaler:
                 # only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)
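
With this parameters() branch, only the LLMAdapter's own weights are exposed for training; the base AutoModel was put in eval() at load time and, assuming the adapter keeps the LLM out of its registered submodules (as sketched earlier), is never yielded. A minimal wiring sketch, assuming `adapter` is the CustomAdapter:

import torch

# only llm_adapter parameters reach the optimizer; the frozen LLM does not
optimizer = torch.optim.AdamW(adapter.parameters(recurse=True), lr=1e-4)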