from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention


class FluxSageAttnProcessor2_0:
    """Flux attention processor that computes attention with SageAttention.

    Mirrors diffusers' FluxAttnProcessor2_0 (the SD3-style joint-attention
    processor), but replaces F.scaled_dot_product_attention with
    sageattention.sageattn.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FluxSageAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

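    # diffusers' Attention module delegates its forward pass to the installed
    # processor, invoking it as processor(attn, hidden_states, ...), which is
    # why `attn` arrives as the first argument below.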
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
        # Imported lazily so the module can be loaded without sageattention installed.
        from sageattention import sageattn

        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

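        # Two call modes: FluxTransformerBlock passes `encoder_hidden_states`
        # (text tokens) alongside the image tokens and expects a
        # (hidden_states, encoder_hidden_states) pair back, while
        # FluxSingleTransformerBlock passes only `hidden_states` and gets a
        # single tensor.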
        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

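        # Flux applies per-head query/key normalization (RMSNorm in diffusers)
        # before the attention math; the None checks keep the processor usable
        # on blocks built without it.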
|
if attn.norm_q is not None:
|
|
query = attn.norm_q(query)
|
|
if attn.norm_k is not None:
|
|
key = attn.norm_k(key)
|
|
|
|
        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention: prepend the text tokens to the image tokens along the
            # sequence dimension so both streams attend jointly
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

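        # From here on, text and image tokens travel as one fused sequence, so
        # `image_rotary_emb` is expected to cover the text positions as well.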
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # SageAttention stands in for F.scaled_dot_product_attention here,
        # consuming the same (batch, heads, seq, head_dim) layout prepared above.
        hidden_states = sageattn(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split the fused sequence back into its text and image halves.
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
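
# Usage sketch (an assumption, not part of the original file): diffusers
# transformer models expose `set_attn_processor`, which broadcasts a single
# processor to every attention module, so installing this processor on an
# already-loaded Flux transformer could look like:
#
#     transformer.set_attn_processor(FluxSageAttnProcessor2_0())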