diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 96c9d800..044b7037 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1540,7 +1540,7 @@ class SDTrainer(BaseSDTrainProcess):
                 prior_pred = prior_pred.detach()
 
             # do the custom adapter after the prior prediction
-            if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
+            if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or self.adapter_config.type in ['llm_adapter', 'text_encoder']):
                 quad_count = random.randint(1, 4)
                 self.adapter.train()
                 conditional_embeds = self.adapter.condition_encoded_embeds(
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index d73570d2..a1cb7ff4 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -44,7 +44,7 @@ from transformers import (
     ConvNextModel,
     ConvNextForImageClassification,
     ConvNextImageProcessor,
-    UMT5EncoderModel, LlamaTokenizerFast
+    UMT5EncoderModel, LlamaTokenizerFast, AutoModel, AutoTokenizer
 )
 
 from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
@@ -52,6 +52,8 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassificatio
 
 from transformers import ViTFeatureExtractor, ViTForImageClassification
 
+from toolkit.models.llm_adapter import LLMAdapter
+
 import torch.nn.functional as F
@@ -198,6 +200,20 @@ class CustomAdapter(torch.nn.Module):
                 raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
             self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
+        elif self.adapter_type == 'llm_adapter':
+            self.te = AutoModel.from_pretrained(self.config.text_encoder_path).to(
+                self.sd_ref().unet.device,
+                dtype=torch_dtype
+            )
+            self.te.eval()
+            self.tokenizer = AutoTokenizer.from_pretrained(self.config.text_encoder_path)
+            self.llm_adapter = LLMAdapter(
+                adapter=self,
+                sd=self.sd_ref(),
+                llm=self.te,
+                tokenizer=self.tokenizer,
+            )
+            self.llm_adapter.to(self.device, torch_dtype)
         elif self.adapter_type == 'te_augmenter':
             self.te_augmenter = TEAugAdapter(self, self.sd_ref())
         elif self.adapter_type == 'vision_direct':
@@ -238,7 +254,7 @@ class CustomAdapter(torch.nn.Module):
     def setup_clip(self):
         adapter_config = self.config
         sd = self.sd_ref()
-        if self.config.type == "text_encoder" or self.config.type == "single_value":
+        if self.config.type in ["text_encoder", "llm_adapter", "single_value"]:
             return
         if self.config.type == 'photo_maker':
             try:
@@ -461,6 +477,9 @@ class CustomAdapter(torch.nn.Module):
         elif self.adapter_type == 'text_encoder':
             state_dict["te_adapter"] = self.te_adapter.state_dict()
             return state_dict
+        elif self.adapter_type == 'llm_adapter':
+            state_dict["llm_adapter"] = self.llm_adapter.state_dict()
+            return state_dict
         elif self.adapter_type == 'te_augmenter':
             if self.config.train_image_encoder:
                 state_dict["vision_encoder"] = self.vision_encoder.state_dict()
@@ -510,6 +529,14 @@ class CustomAdapter(torch.nn.Module):
                     self.unconditional_embeds = self.te_adapter.encode_text(prompt).detach()
                 else:
                     self.conditional_embeds = self.te_adapter.encode_text(prompt).detach()
+        elif self.adapter_type == 'llm_adapter':
+            # todo allow for training
+            with torch.no_grad():
+                # encode and save the embeds
+                if is_unconditional:
+                    self.unconditional_embeds = self.llm_adapter.encode_text(prompt).detach()
+                else:
+                    self.conditional_embeds = self.llm_adapter.encode_text(prompt).detach()
             return prompt
         elif self.adapter_type == 'photo_maker':
             if is_unconditional:
@@ -613,11 +640,20 @@ class CustomAdapter(torch.nn.Module):
         quad_count=4,
         is_generating_samples=False,
     ) -> PromptEmbeds:
-        if self.adapter_type == 'text_encoder' and is_generating_samples:
+        if self.adapter_type == 'text_encoder':
             # replace the prompt embed with ours
             if is_unconditional:
                 return self.unconditional_embeds.clone()
             return self.conditional_embeds.clone()
+        if self.adapter_type == 'llm_adapter':
+            # replace the prompt embed with ours
+            if is_unconditional:
+                prompt_embeds.text_embeds = self.unconditional_embeds.text_embeds.clone()
+                prompt_embeds.attention_mask = self.unconditional_embeds.attention_mask.clone()
+                return prompt_embeds
+            prompt_embeds.text_embeds = self.conditional_embeds.text_embeds.clone()
+            prompt_embeds.attention_mask = self.conditional_embeds.attention_mask.clone()
+            return prompt_embeds
 
         if self.adapter_type == 'ilora':
             return prompt_embeds
@@ -977,6 +1013,8 @@ class CustomAdapter(torch.nn.Module):
         elif self.config.type == 'text_encoder':
             for attn_processor in self.te_adapter.adapter_modules:
                 yield from attn_processor.parameters(recurse)
+        elif self.config.type == 'llm_adapter':
+            yield from self.llm_adapter.parameters(recurse)
         elif self.config.type == 'vision_direct':
             if self.config.train_scaler:
                 # only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)
diff --git a/toolkit/models/llm_adapter.py b/toolkit/models/llm_adapter.py
new file mode 100644
index 00000000..7efcd913
--- /dev/null
+++ b/toolkit/models/llm_adapter.py
@@ -0,0 +1,127 @@
+from functools import partial
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import weakref
+from typing import List, Optional, Tuple, Union, TYPE_CHECKING
+
+
+from transformers import AutoModel, AutoTokenizer, Qwen2Model, LlamaModel, Qwen2Tokenizer, LlamaTokenizer
+
+from toolkit import train_tools
+from toolkit.prompt_utils import PromptEmbeds
+from diffusers import Transformer2DModel
+
+
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
+    from toolkit.custom_adapter import CustomAdapter
+
+LLM = Union[Qwen2Model, LlamaModel]
+LLMTokenizer = Union[Qwen2Tokenizer, LlamaTokenizer]
+
+
+def new_context_embedder_forward(self, x):
+    if self._adapter_ref().is_active:
+        x = self._context_embedder_ref()(x)
+    else:
+        x = self._orig_forward(x)
+    return x
+
+
+class LLMAdapter(torch.nn.Module):
+    def __init__(
+        self,
+        adapter: 'CustomAdapter',
+        sd: 'StableDiffusion',
+        llm: LLM,
+        tokenizer: LLMTokenizer,
+    ):
+        super(LLMAdapter, self).__init__()
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.sd_ref: weakref.ref = weakref.ref(sd)
+        self.llm_ref: weakref.ref = weakref.ref(llm)
+        self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
+
+        self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts."
+
+        self.hidden_size = llm.config.hidden_size
+
+        if sd.is_flux:
+            self.context_embedder = nn.Linear(
+                self.hidden_size, sd.unet.inner_dim)
+            self.sequence_length = 512
+            sd.unet.context_embedder._orig_forward = sd.unet.context_embedder.forward
+            sd.unet.context_embedder.forward = partial(
+                new_context_embedder_forward, sd.unet.context_embedder)
+            sd.unet.context_embedder._context_embedder_ref = weakref.ref(self.context_embedder)
+            # add a is active property to the context embedder
+            sd.unet.context_embedder._adapter_ref = self.adapter_ref
+
+        elif sd.is_lumina2:
+            self.context_embedder = nn.Linear(
+                self.hidden_size, sd.unet.hidden_size)
+            self.sequence_length = 256
+        else:
+            raise ValueError(
+                "llm adapter currently only supports flux or lumina2")
+
+    def _get_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        max_sequence_length: int = 256,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        tokenizer = self.tokenizer_ref()
+        text_encoder = self.llm_ref()
+        device = text_encoder.device
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids.to(device)
+
+        prompt_attention_mask = text_inputs.attention_mask.to(device)
+        prompt_embeds = text_encoder(
+            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+        )
+        prompt_embeds = prompt_embeds.hidden_states[-2]
+
+        dtype = text_encoder.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        return prompt_embeds, prompt_attention_mask
+
+    # make a getter to see if is active
+
+    @property
+    def is_active(self):
+        return self.adapter_ref().is_active
+
+    def encode_text(self, prompt):
+
+        prompt = prompt if isinstance(prompt, list) else [prompt]
+
+        prompt = [self.system_prompt + " " + p for p in prompt]
+
+        prompt_embeds, prompt_attention_mask = self._get_prompt_embeds(
+            prompt=prompt,
+            max_sequence_length=self.sequence_length,
+        )
+
+        prompt_embeds = PromptEmbeds(
+            prompt_embeds,
+            attention_mask=prompt_attention_mask,
+        ).detach()
+
+        return prompt_embeds
+
+    def forward(self, input):
+        return input
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 8b2a6506..212e68c2 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1356,8 +1356,7 @@ class StableDiffusion:
                 conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False)
                 unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True)
 
-            if self.adapter is not None and isinstance(self.adapter,
-                                                       CustomAdapter) and validation_image is not None:
+            if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
                 conditional_embeds = self.adapter.condition_encoded_embeds(
                     tensors_0_1=validation_image,
                     prompt_embeds=conditional_embeds,