Rework the IP adapter and vision direct adapters to apply to the single transformer blocks as well, even though those blocks do not use cross-attention.

This commit is contained in:
Jaret Burkett
2024-09-01 10:40:42 -06:00
parent 7ed8c51f20
commit 7d9ab22405
2 changed files with 129 additions and 147 deletions

View File

@@ -269,16 +269,7 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
is_active = self.adapter_ref().is_active
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
@@ -297,7 +288,44 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
if attn.norm_k is not None:
key = attn.norm_k(key)
# will be none if disabled
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# begin ip adapter
if not is_active:
ip_hidden_states = None
else:
@@ -309,47 +337,6 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
raise ValueError("Unconditional is None but should not be")
ip_hidden_states = torch.cat([self.adapter_ref().last_unconditional, ip_hidden_states], dim=0)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# do ip adapter
# will be none if disabled
if ip_hidden_states is not None:
# apply scaler
if self.train_scaler:
@@ -365,8 +352,6 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
@@ -376,26 +361,23 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
scale = self.scale
hidden_states = hidden_states + scale * ip_hidden_states
# end ip adapter
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
@@ -659,9 +641,9 @@ class IPAdapter(torch.nn.Module):
for i, module in transformer.transformer_blocks.named_children():
attn_processor_keys.append(f"transformer_blocks.{i}.attn")
# single transformer blocks do not have cross attn
# for i, module in transformer.single_transformer_blocks.named_children():
# attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
# single transformer blocks do not have cross attn, but we will do them anyway
for i, module in transformer.single_transformer_blocks.named_children():
attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
else:
attn_processor_keys = list(sd.unet.attn_processors.keys())
@@ -695,7 +677,7 @@ class IPAdapter(torch.nn.Module):
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = sd.unet.config['block_out_channels'][block_id]
elif name.startswith("transformer"):
elif name.startswith("transformer") or name.startswith("single_transformer"):
if is_flux:
hidden_size = 3072
else:
@@ -773,11 +755,20 @@ class IPAdapter(torch.nn.Module):
transformer: FluxTransformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
# do single blocks too even though they dont have cross attn
for i, module in transformer.single_transformer_blocks.named_children():
module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
self.adapter_modules = torch.nn.ModuleList(
[
transformer.transformer_blocks[i].attn.processor for i in
range(len(transformer.transformer_blocks))
])
] + [
transformer.single_transformer_blocks[i].attn.processor for i in
range(len(transformer.single_transformer_blocks))
]
)
else:
sd.unet.set_attn_processor(attn_procs)
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

View File

@@ -323,17 +323,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
is_active = self.adapter_ref().is_active
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
@@ -352,36 +342,34 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
@@ -391,8 +379,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# do ip adapter
# will be none if disabled
# begin ip adapter
if self.is_active and self.conditional_embeds is not None:
adapter_hidden_states = self.conditional_embeds
if adapter_hidden_states.shape[0] < batch_size:
@@ -413,8 +400,6 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
vd_hidden_states = F.scaled_dot_product_attention(
query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
@@ -424,25 +409,21 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states + self.scale * vd_hidden_states
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
return hidden_states
class VisionDirectAdapter(torch.nn.Module):
def __init__(
@@ -482,9 +463,9 @@ class VisionDirectAdapter(torch.nn.Module):
for i, module in transformer.transformer_blocks.named_children():
attn_processor_keys.append(f"transformer_blocks.{i}.attn")
# single transformer blocks do not have cross attn
# for i, module in transformer.single_transformer_blocks.named_children():
# attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
# single transformer blocks do not have cross attn, but we will do them anyway
for i, module in transformer.single_transformer_blocks.named_children():
attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
else:
attn_processor_keys = list(sd.unet.attn_processors.keys())
@@ -501,7 +482,7 @@ class VisionDirectAdapter(torch.nn.Module):
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = sd.unet.config['block_out_channels'][block_id]
elif name.startswith("transformer"):
elif name.startswith("transformer") or name.startswith("single_transformer"):
if is_flux:
hidden_size = 3072
else:
@@ -596,23 +577,32 @@ class VisionDirectAdapter(torch.nn.Module):
transformer: FluxTransformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
# do single blocks too even though they dont have cross attn
for i, module in transformer.single_transformer_blocks.named_children():
module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
self.adapter_modules = torch.nn.ModuleList(
[
transformer.transformer_blocks[i].attn.processor for i in
range(len(transformer.transformer_blocks))
])
] + [
transformer.single_transformer_blocks[i].attn.processor for i in
range(len(transformer.single_transformer_blocks))
]
)
else:
sd.unet.set_attn_processor(attn_procs)
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
# add the mlp layer
self.mlp = MLP(
in_dim=self.token_size,
out_dim=self.token_size,
hidden_dim=self.token_size,
# dropout=0.1,
use_residual=True
)
# # add the mlp layer
# self.mlp = MLP(
# in_dim=self.token_size,
# out_dim=self.token_size,
# hidden_dim=self.token_size,
# # dropout=0.1,
# use_residual=True
# )
# make a getter to see if is active
@property
@@ -620,4 +610,5 @@ class VisionDirectAdapter(torch.nn.Module):
return self.adapter_ref().is_active
def forward(self, input):
return self.mlp(input)
# return self.mlp(input)
return input