mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added te aug adapter
This commit is contained in:
253
toolkit/models/te_aug_adapter.py
Normal file
253
toolkit/models/te_aug_adapter.py
Normal file
@@ -0,0 +1,253 @@
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import weakref
|
||||
from typing import Union, TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
|
||||
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention
|
||||
|
||||
from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.resampler import Resampler
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
|
||||
|
||||
class TEAugAdapterCLIPAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, attn_module: 'CLIPAttention', adapter: 'TEAugAdapter'):
|
||||
super().__init__()
|
||||
self.adapter_ref: weakref.ref = weakref.ref(adapter)
|
||||
self.attn_module_ref: weakref.ref = weakref.ref(attn_module)
|
||||
self.k_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
|
||||
self.v_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
|
||||
# copy the weights from the original module
|
||||
self.k_proj_adapter.weight.data = attn_module.k_proj.weight.data.clone() * 0.01
|
||||
self.v_proj_adapter.weight.data = attn_module.v_proj.weight.data.clone() * 0.01
|
||||
#reset the bias
|
||||
self.k_proj_adapter.bias.data = attn_module.k_proj.bias.data.clone() * 0.001
|
||||
self.v_proj_adapter.bias.data = attn_module.v_proj.bias.data.clone() * 0.001
|
||||
|
||||
self.zipper = ZipperModule(
|
||||
in_size=attn_module.embed_dim,
|
||||
in_tokens=77 * 2,
|
||||
out_size=attn_module.embed_dim,
|
||||
out_tokens=77,
|
||||
hidden_size=attn_module.embed_dim,
|
||||
hidden_tokens=77,
|
||||
)
|
||||
# self.k_proj_adapter.weight.data = torch.zeros_like(attn_module.k_proj.weight.data)
|
||||
# self.v_proj_adapter.weight.data = torch.zeros_like(attn_module.v_proj.weight.data)
|
||||
# #reset the bias
|
||||
# self.k_proj_adapter.bias.data = torch.zeros_like(attn_module.k_proj.bias.data)
|
||||
# self.v_proj_adapter.bias.data = torch.zeros_like(attn_module.v_proj.bias.data)
|
||||
|
||||
# replace the original forward with our forward
|
||||
self.original_forward = attn_module.forward
|
||||
attn_module.forward = self.forward
|
||||
|
||||
|
||||
@property
|
||||
def is_active(self):
|
||||
return self.adapter_ref().is_active
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
causal_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
attn_module = self.attn_module_ref()
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = attn_module.q_proj(hidden_states) * attn_module.scale
|
||||
key_states = attn_module._shape(attn_module.k_proj(hidden_states), -1, bsz)
|
||||
value_states = attn_module._shape(attn_module.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
proj_shape = (bsz * attn_module.num_heads, -1, attn_module.head_dim)
|
||||
query_states = attn_module._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * attn_module.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * attn_module.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
# apply the causal_attention_mask first
|
||||
if causal_attention_mask is not None:
|
||||
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
|
||||
f" {causal_attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit akward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * attn_module.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
adapter: 'CustomAdapter' = self.adapter_ref().adapter_ref()
|
||||
if self.adapter_ref().is_active and adapter.conditional_embeds is not None:
|
||||
# apply the adapter
|
||||
|
||||
if adapter.is_unconditional_run:
|
||||
embeds = adapter.unconditional_embeds
|
||||
else:
|
||||
embeds = adapter.conditional_embeds
|
||||
# if the shape is not the same on batch, we are doing cfg and need to concat unconditional as well
|
||||
if embeds.size(0) != bsz:
|
||||
embeds = torch.cat([adapter.unconditional_embeds, embeds], dim=0)
|
||||
|
||||
key_states_raw = self.k_proj_adapter(embeds)
|
||||
key_states = attn_module._shape(key_states_raw, -1, bsz)
|
||||
value_states_raw = self.v_proj_adapter(embeds)
|
||||
value_states = attn_module._shape(value_states_raw, -1, bsz)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
|
||||
attn_output_adapter = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output_adapter.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output_adapter` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
|
||||
f" {attn_output_adapter.size()}"
|
||||
)
|
||||
|
||||
attn_output_adapter = attn_output_adapter.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
|
||||
attn_output_adapter = attn_output_adapter.transpose(1, 2)
|
||||
attn_output_adapter = attn_output_adapter.reshape(bsz, tgt_len, embed_dim)
|
||||
|
||||
attn_output_adapter = self.zipper(torch.cat([attn_output_adapter, attn_output], dim=1))
|
||||
|
||||
# attn_output_adapter = attn_module.out_proj(attn_output_adapter)
|
||||
attn_output = attn_output + attn_output_adapter
|
||||
|
||||
attn_output = attn_module.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
|
||||
class TEAugAdapter(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
adapter: 'CustomAdapter',
|
||||
sd: 'StableDiffusion',
|
||||
):
|
||||
super(TEAugAdapter, self).__init__()
|
||||
self.adapter_ref: weakref.ref = weakref.ref(adapter)
|
||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||
|
||||
if isinstance(sd.text_encoder, list):
|
||||
raise ValueError("Dual text encoders is not yet supported")
|
||||
|
||||
# dim will come from text encoder
|
||||
# dim = sd.unet.config['cross_attention_dim']
|
||||
text_encoder: CLIPTextModel = sd.text_encoder
|
||||
dim = text_encoder.config.hidden_size
|
||||
|
||||
clip_encoder: CLIPEncoder = text_encoder.text_model.encoder
|
||||
# dim = clip_encoder.layers[-1].self_attn
|
||||
|
||||
if hasattr(adapter.vision_encoder.config, 'hidden_sizes'):
|
||||
embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
|
||||
else:
|
||||
embedding_dim = adapter.vision_encoder.config.hidden_size
|
||||
|
||||
image_encoder_state_dict = adapter.vision_encoder.state_dict()
|
||||
# max_seq_len = CLIP tokens + CLS token
|
||||
in_tokens = 257
|
||||
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
|
||||
# clip
|
||||
in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
|
||||
|
||||
if adapter.config.image_encoder_arch.startswith('convnext'):
|
||||
in_tokens = 16 * 16
|
||||
embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
|
||||
|
||||
out_tokens = adapter.config.num_tokens if adapter.config.num_tokens > 0 else in_tokens
|
||||
self.image_proj_model = ZipperModule(
|
||||
in_size=embedding_dim,
|
||||
in_tokens=in_tokens,
|
||||
out_size=dim,
|
||||
out_tokens=out_tokens,
|
||||
hidden_size=dim,
|
||||
hidden_tokens=out_tokens,
|
||||
)
|
||||
# init adapter modules
|
||||
attn_procs = {}
|
||||
for idx, layer in enumerate(clip_encoder.layers):
|
||||
name = f"clip_attention.{idx}"
|
||||
attn_procs[name] = TEAugAdapterCLIPAttention(
|
||||
layer.self_attn,
|
||||
self
|
||||
)
|
||||
|
||||
self.adapter_modules = torch.nn.ModuleList(list(attn_procs.values()))
|
||||
|
||||
# make a getter to see if is active
|
||||
@property
|
||||
def is_active(self):
|
||||
return self.adapter_ref().is_active
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
# # apply the adapter
|
||||
input = self.image_proj_model(input)
|
||||
# self.embeds = input
|
||||
return input
|
||||
Reference in New Issue
Block a user