mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
fixed various issues with llm attention masking. Added block training on the llm adapter.
This commit is contained in:
@@ -216,6 +216,9 @@ class AdapterConfig:
|
|||||||
self.conv_pooling: bool = kwargs.get('conv_pooling', False)
|
self.conv_pooling: bool = kwargs.get('conv_pooling', False)
|
||||||
self.conv_pooling_stacks: int = kwargs.get('conv_pooling_stacks', 1)
|
self.conv_pooling_stacks: int = kwargs.get('conv_pooling_stacks', 1)
|
||||||
self.sparse_autoencoder_dim: Optional[int] = kwargs.get('sparse_autoencoder_dim', None)
|
self.sparse_autoencoder_dim: Optional[int] = kwargs.get('sparse_autoencoder_dim', None)
|
||||||
|
|
||||||
|
# for llm adapter
|
||||||
|
self.num_cloned_blocks: int = kwargs.get('num_cloned_blocks', 0)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfig:
|
class EmbeddingConfig:
|
||||||
|
|||||||
@@ -212,6 +212,7 @@ class CustomAdapter(torch.nn.Module):
|
|||||||
sd=self.sd_ref(),
|
sd=self.sd_ref(),
|
||||||
llm=self.te,
|
llm=self.te,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
|
num_cloned_blocks=self.config.num_cloned_blocks,
|
||||||
)
|
)
|
||||||
self.llm_adapter.to(self.device, torch_dtype)
|
self.llm_adapter.to(self.device, torch_dtype)
|
||||||
elif self.adapter_type == 'te_augmenter':
|
elif self.adapter_type == 'te_augmenter':
|
||||||
|
|||||||
@@ -5,14 +5,15 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import weakref
|
import weakref
|
||||||
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
|
from typing import Any, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
|
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock
|
||||||
from transformers import AutoModel, AutoTokenizer, Qwen2Model, LlamaModel, Qwen2Tokenizer, LlamaTokenizer
|
from transformers import AutoModel, AutoTokenizer, Qwen2Model, LlamaModel, Qwen2Tokenizer, LlamaTokenizer
|
||||||
|
|
||||||
from toolkit import train_tools
|
from toolkit import train_tools
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
from diffusers import Transformer2DModel
|
from diffusers import Transformer2DModel
|
||||||
|
from toolkit.dequantize import patch_dequantization_on_save
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -30,6 +31,19 @@ def new_context_embedder_forward(self, x):
|
|||||||
x = self._orig_forward(x)
|
x = self._orig_forward(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def new_block_forward(
|
||||||
|
self: FluxTransformerBlock,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
temb: torch.Tensor,
|
||||||
|
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self._adapter_ref().is_active:
|
||||||
|
return self._new_block_ref()(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)
|
||||||
|
else:
|
||||||
|
return self._orig_forward(hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs)
|
||||||
|
|
||||||
|
|
||||||
class LLMAdapter(torch.nn.Module):
|
class LLMAdapter(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -38,12 +52,15 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
sd: 'StableDiffusion',
|
sd: 'StableDiffusion',
|
||||||
llm: LLM,
|
llm: LLM,
|
||||||
tokenizer: LLMTokenizer,
|
tokenizer: LLMTokenizer,
|
||||||
|
num_cloned_blocks: int = 0,
|
||||||
):
|
):
|
||||||
super(LLMAdapter, self).__init__()
|
super(LLMAdapter, self).__init__()
|
||||||
self.adapter_ref: weakref.ref = weakref.ref(adapter)
|
self.adapter_ref: weakref.ref = weakref.ref(adapter)
|
||||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||||
self.llm_ref: weakref.ref = weakref.ref(llm)
|
self.llm_ref: weakref.ref = weakref.ref(llm)
|
||||||
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
||||||
|
self.num_cloned_blocks = num_cloned_blocks
|
||||||
|
self.apply_embedding_mask = False
|
||||||
# make sure we can pad
|
# make sure we can pad
|
||||||
if tokenizer.pad_token is None:
|
if tokenizer.pad_token is None:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
@@ -65,8 +82,11 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
print(f"System prompt length: {self.system_prompt_length}")
|
print(f"System prompt length: {self.system_prompt_length}")
|
||||||
|
|
||||||
self.hidden_size = llm.config.hidden_size
|
self.hidden_size = llm.config.hidden_size
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
|
||||||
if sd.is_flux:
|
if sd.is_flux:
|
||||||
|
self.apply_embedding_mask = True
|
||||||
self.context_embedder = nn.Linear(
|
self.context_embedder = nn.Linear(
|
||||||
self.hidden_size, sd.unet.inner_dim)
|
self.hidden_size, sd.unet.inner_dim)
|
||||||
self.sequence_length = 512
|
self.sequence_length = 512
|
||||||
@@ -77,6 +97,25 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
# add a is active property to the context embedder
|
# add a is active property to the context embedder
|
||||||
sd.unet.context_embedder._adapter_ref = self.adapter_ref
|
sd.unet.context_embedder._adapter_ref = self.adapter_ref
|
||||||
|
|
||||||
|
for idx in range(self.num_cloned_blocks):
|
||||||
|
block = FluxTransformerBlock(
|
||||||
|
dim=sd.unet.inner_dim,
|
||||||
|
num_attention_heads=24,
|
||||||
|
attention_head_dim=128,
|
||||||
|
)
|
||||||
|
# patch it in case it is quantized
|
||||||
|
patch_dequantization_on_save(sd.unet.transformer_blocks[idx])
|
||||||
|
state_dict = sd.unet.transformer_blocks[idx].state_dict()
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
block.state_dict()[key].copy_(value)
|
||||||
|
blocks.append(block)
|
||||||
|
orig_block = sd.unet.transformer_blocks[idx]
|
||||||
|
orig_block._orig_forward = orig_block.forward
|
||||||
|
orig_block.forward = partial(
|
||||||
|
new_block_forward, orig_block)
|
||||||
|
orig_block._new_block_ref = weakref.ref(block)
|
||||||
|
orig_block._adapter_ref = self.adapter_ref
|
||||||
|
|
||||||
elif sd.is_lumina2:
|
elif sd.is_lumina2:
|
||||||
self.context_embedder = nn.Linear(
|
self.context_embedder = nn.Linear(
|
||||||
self.hidden_size, sd.unet.hidden_size)
|
self.hidden_size, sd.unet.hidden_size)
|
||||||
@@ -84,6 +123,8 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"llm adapter currently only supports flux or lumina2")
|
"llm adapter currently only supports flux or lumina2")
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
|
||||||
def _get_prompt_embeds(
|
def _get_prompt_embeds(
|
||||||
self,
|
self,
|
||||||
@@ -103,11 +144,12 @@ class LLMAdapter(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
text_input_ids = text_inputs.input_ids.to(device)
|
text_input_ids = text_inputs.input_ids.to(device)
|
||||||
|
|
||||||
# remove the system prompt from the input
|
|
||||||
text_input_ids = text_input_ids[:, self.system_prompt_length:]
|
|
||||||
|
|
||||||
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
||||||
|
|
||||||
|
# remove the system prompt from the input and attention mask
|
||||||
|
text_input_ids = text_input_ids[:, self.system_prompt_length:]
|
||||||
|
prompt_attention_mask = prompt_attention_mask[:, self.system_prompt_length:]
|
||||||
|
|
||||||
prompt_embeds = text_encoder(
|
prompt_embeds = text_encoder(
|
||||||
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user