From bd8d7dc0817590cfc22ff97bb2dbe66034ab48e7 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 14 Feb 2025 11:24:01 -0700
Subject: [PATCH] fixed various issues with llm attention masking. Added block
 training on the llm adapter.

---
 toolkit/config_modules.py     |  3 ++
 toolkit/custom_adapter.py     |  1 +
 toolkit/models/llm_adapter.py | 54 +++++++++++++++++++++++++++++++----
 3 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 284f4fdb..bee186e7 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -216,6 +216,9 @@ class AdapterConfig:
         self.conv_pooling: bool = kwargs.get('conv_pooling', False)
         self.conv_pooling_stacks: int = kwargs.get('conv_pooling_stacks', 1)
         self.sparse_autoencoder_dim: Optional[int] = kwargs.get('sparse_autoencoder_dim', None)
+
+        # for llm adapter
+        self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
 
 
 class EmbeddingConfig:
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index a1cb7ff4..c2c03952 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -212,6 +212,7 @@ class CustomAdapter(torch.nn.Module):
                 sd=self.sd_ref(),
                 llm=self.te,
                 tokenizer=self.tokenizer,
+                num_cloned_blocks=self.config.num_cloned_blocks,
             )
             self.llm_adapter.to(self.device, torch_dtype)
         elif self.adapter_type == 'te_augmenter':
diff --git a/toolkit/models/llm_adapter.py b/toolkit/models/llm_adapter.py
index 51880ffa..97ed2455 100644
--- a/toolkit/models/llm_adapter.py
+++ b/toolkit/models/llm_adapter.py
@@ -5,14 +5,15 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import weakref
-from typing import List, Optional, Tuple, Union, TYPE_CHECKING
-
+from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
+from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
 from transformers import AutoModel, AutoTokenizer, Qwen2Model, LlamaModel, Qwen2Tokenizer, LlamaTokenizer
 from toolkit import train_tools
 from toolkit.prompt_utils import PromptEmbeds
 from diffusers import Transformer2DModel
+from toolkit.dequantize import patch_dequantization_on_save
 
 
 if TYPE_CHECKING:
@@ -30,6 +31,19 @@ def new_context_embedder_forward(self, x):
     x = self._orig_forward(x)
     return x
 
+def new_block_forward(
+    self: FluxTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if self._adapter_ref().is_active:
+        return self._new_block_ref()(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)
+    else:
+        return self._orig_forward(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)
+
 
 class LLMAdapter(torch.nn.Module):
     def __init__(
@@ -38,12 +52,15 @@ class LLMAdapter(torch.nn.Module):
         sd: 'StableDiffusion',
         llm: LLM,
         tokenizer: LLMTokenizer,
+        num_cloned_blocks: int = 0,
     ):
         super(LLMAdapter, self).__init__()
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.llm_ref: weakref.ref = weakref.ref(llm)
         self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
+        self.num_cloned_blocks = num_cloned_blocks
+        self.apply_embedding_mask = False
         # make sure we can pad
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
@@ -65,8 +82,11 @@ class LLMAdapter(torch.nn.Module):
         print(f"System prompt length: {self.system_prompt_length}")
 
         self.hidden_size = llm.config.hidden_size
+
+        blocks = []
 
         if sd.is_flux:
+            self.apply_embedding_mask = True
             self.context_embedder = nn.Linear(
                 self.hidden_size, sd.unet.inner_dim)
             self.sequence_length = 512
@@ -77,6 +97,25 @@ class LLMAdapter(torch.nn.Module):
             # add a is active property to the context embedder
             sd.unet.context_embedder._adapter_ref = self.adapter_ref
 
+            for idx in range(self.num_cloned_blocks):
+                block = FluxTransformerBlock(
+                    dim=sd.unet.inner_dim,
+                    num_attention_heads=24,
+                    attention_head_dim=128,
+                )
+                # patch it in case it is quantized
+                patch_dequantization_on_save(sd.unet.transformer_blocks[idx])
+                state_dict = sd.unet.transformer_blocks[idx].state_dict()
+                for key, value in state_dict.items():
+                    block.state_dict()[key].copy_(value)
+                blocks.append(block)
+                orig_block = sd.unet.transformer_blocks[idx]
+                orig_block._orig_forward = orig_block.forward
+                orig_block.forward = partial(
+                    new_block_forward, orig_block)
+                orig_block._new_block_ref = weakref.ref(block)
+                orig_block._adapter_ref = self.adapter_ref
+
         elif sd.is_lumina2:
             self.context_embedder = nn.Linear(
                 self.hidden_size, sd.unet.hidden_size)
@@ -84,6 +123,8 @@
         else:
             raise ValueError(
                 "llm adapter currently only supports flux or lumina2")
+
+        self.blocks = nn.ModuleList(blocks)
 
     def _get_prompt_embeds(
         self,
@@ -103,11 +144,12 @@
         )
 
         text_input_ids = text_inputs.input_ids.to(device)
-
-        # remove the system prompt from the input
-        text_input_ids = text_input_ids[:, self.system_prompt_length:]
-
         prompt_attention_mask = text_inputs.attention_mask.to(device)
+
+        # remove the system prompt from the input and attention mask
+        text_input_ids = text_input_ids[:, self.system_prompt_length:]
+        prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
+
         prompt_embeds = text_encoder(
             text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
         )
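
Note (reviewer aid, not part of the patch): the cloned-block change above works by saving each original transformer block's forward, then monkey-patching it to dispatch to a trainable clone whenever the adapter is active. Below is a minimal, self-contained sketch of that wiring pattern; plain nn.Linear modules stand in for FluxTransformerBlock, and the names patched_forward, DummyAdapter, and _clone_ref are illustrative only and do not exist in the toolkit.

# Sketch only: mirrors the _orig_forward / partial / weakref wiring from the
# patch, using stand-in modules and hypothetical names.
import weakref
from functools import partial

import torch
import torch.nn as nn


def patched_forward(self, x):
    # Route through the trainable clone while the adapter is active,
    # otherwise fall back to the block's original forward.
    if self._adapter_ref().is_active:
        return self._clone_ref()(x)
    return self._orig_forward(x)


class DummyAdapter:
    def __init__(self):
        self.is_active = True


adapter = DummyAdapter()

orig_block = nn.Linear(4, 4)   # stands in for sd.unet.transformer_blocks[idx]
clone = nn.Linear(4, 4)        # stands in for the cloned FluxTransformerBlock
clone.load_state_dict(orig_block.state_dict())

orig_block._orig_forward = orig_block.forward
orig_block.forward = partial(patched_forward, orig_block)
orig_block._clone_ref = weakref.ref(clone)
orig_block._adapter_ref = weakref.ref(adapter)

with torch.no_grad():
    clone.weight.add_(1.0)     # make the clone's output distinguishable

x = torch.randn(2, 4)
assert torch.allclose(orig_block(x), clone(x))                      # active -> clone
adapter.is_active = False
assert torch.allclose(orig_block(x), orig_block._orig_forward(x))   # inactive -> original path

Because only weakrefs to the adapter and the clone are stored on the patched block, the block itself never owns its replacement, which is the same design the patch uses to keep the cloned blocks in self.blocks (and thus in the adapter's state dict) rather than inside the base model.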