mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-27 03:19:47 +00:00
115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
import torch
|
|
|
|
from backend.args import args
|
|
|
|
if args.xformers:
|
|
import xformers
|
|
import xformers.ops
|
|
|
|
|
|
def attention_xformers(q, k, v, heads, mask=None):
    """Multi-head attention via ``xformers.ops.memory_efficient_attention``.

    Args:
        q: Query tensor, 3-D ``(batch, q_len, heads * dim_head)``.
        k: Key tensor laid out like ``q`` (its sequence length may differ).
        v: Value tensor laid out like ``k``.
        heads: Number of attention heads; the last dim of ``q`` is split into
            ``heads`` chunks of ``dim_head``.
        mask: Optional attention bias. Either 3-D and broadcastable to
            ``(batch * heads, q_len, kv_len)``, or 4-D ``(batch, heads, -1, kv_len)``
            (the layout ``AttentionProcessorForge`` produces); a 4-D mask has its
            heads folded into the batch dimension to match q/k/v.

    Returns:
        Tensor of shape ``(batch, q_len, heads * dim_head)``.
    """
    b, _, dim_head = q.shape
    dim_head //= heads

    def _split_heads(t):
        # (b, L, heads * dim_head) -> (b * heads, L, dim_head): heads are folded
        # into the batch dimension and the result made contiguous for xformers.
        return (
            t.reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous()
        )

    q, k, v = map(_split_heads, (q, k, v))

    if mask is not None:
        if mask.ndim == 4:
            # (b, heads, Lq_or_1, Lk) -> (b * heads, Lq_or_1, Lk), matching the
            # (b * heads) ordering produced by _split_heads above.
            mask = mask.reshape(-1, mask.shape[-2], mask.shape[-1])
        kv_len = mask.shape[-1]
        # Copy the bias into an over-allocated buffer and slice it back to
        # kv_len, so the storage row length is padded to a multiple of 8
        # (presumably xformers' attn_bias alignment requirement). The padding is
        # derived from the key length so the copy always fits, including
        # cross-attention where kv_len may exceed q_len.
        pad = 8 - kv_len % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], kv_len + pad], dtype=q.dtype, device=q.device)
        mask_out[:, :, :kv_len] = mask
        mask = mask_out[:, :, :kv_len]

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    # (b * heads, q_len, dim_head) -> (b, q_len, heads * dim_head)
    return (
        out.reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
|
|
|
|
|
|
def attention_pytorch(q, k, v, heads, mask=None):
    """Multi-head attention via ``torch.nn.functional.scaled_dot_product_attention``.

    The last dimension of ``q``/``k``/``v`` is split into ``heads`` heads, the
    heads are attended independently (no dropout, non-causal), and the result is
    merged back into a single channel dimension.

    Args:
        q: Query tensor, 3-D ``(batch, q_len, heads * dim_head)``.
        k: Key tensor laid out like ``q`` (its sequence length may differ).
        v: Value tensor laid out like ``k``.
        heads: Number of attention heads.
        mask: Optional ``attn_mask`` forwarded unchanged to SDPA; must be
            broadcastable to ``(batch, heads, q_len, kv_len)``.

    Returns:
        Tensor of shape ``(batch, q_len, heads * dim_head)``.
    """
    batch, _, inner_dim = q.shape
    head_dim = inner_dim // heads

    def to_heads(x):
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return x.view(batch, -1, heads, head_dim).transpose(1, 2)

    q_heads = to_heads(q)
    k_heads = to_heads(k)
    v_heads = to_heads(v)

    attended = torch.nn.functional.scaled_dot_product_attention(
        q_heads,
        k_heads,
        v_heads,
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=False,
    )

    # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
    merged = attended.transpose(1, 2)
    return merged.reshape(batch, -1, heads * head_dim)
|
|
|
|
|
|
# Select the attention backend once, at import time, from the --xformers flag.
# Every later call goes through ``attention_function``.
if not args.xformers:
    print("Using pytorch cross attention")
    attention_function = attention_pytorch
else:
    print("Using xformers cross attention")
    attention_function = attention_xformers
|
|
|
|
|
|
class AttentionProcessorForge:
    """Attention processor that delegates the attention kernel to ``attention_function``.

    Everything around the kernel reads its hooks from the ``attn`` module:
    optional spatial norm, 4-D -> token flattening, mask preparation, optional
    group norm, q/k/v projections, output projection, optional residual, and
    output rescale. ``attn`` is presumably a diffusers ``Attention`` module —
    confirm against callers. Only the attention step itself goes through this
    module's ``attention_function`` (xformers or PyTorch SDPA).
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *_extra_args,
        **kwargs,
    ):
        """Run attention for ``attn`` on ``hidden_states``.

        Args:
            attn: Module supplying ``to_q``/``to_k``/``to_v``/``to_out``,
                ``heads``, the optional ``spatial_norm``/``group_norm`` hooks,
                ``norm_cross``/``norm_encoder_hidden_states``,
                ``prepare_attention_mask``, ``residual_connection`` and
                ``rescale_output_factor``.
            hidden_states: A 3-D token sequence or a 4-D
                ``(batch, channel, height, width)`` feature map. 4-D input is
                flattened to tokens and restored to 4-D before returning.
            encoder_hidden_states: Key/value source for cross-attention.
                ``None`` (the default) means self-attention over ``hidden_states``.
            attention_mask: Optional mask; passed through
                ``attn.prepare_attention_mask`` and reshaped to
                ``(batch, heads, -1, kv_len)`` before use.
            temb: Forwarded to ``attn.spatial_norm`` when that hook is set.
            *_extra_args, **kwargs: Accepted for call-site compatibility and
                ignored. (The varargs name deliberately avoids shadowing the
                module-level ``args`` import.)

        Returns:
            The attended hidden states; 4-D when the input was 4-D.
        """
        # Residual is taken from the raw input, before any normalization/flattening.
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # (B, C, H, W) -> (B, H*W, C) token sequence.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized against the key/value sequence: the encoder states
        # for cross-attention, hidden_states otherwise.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Split the mask per head: (batch, heads, -1, kv_len).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            # group_norm normalizes over channels, so move channels to dim 1 and back.
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        hidden_states = attention_function(query, key, value, heads=attn.heads, mask=attention_mask)

        # to_out[0] / to_out[1]: presumably the output projection and dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # (B, H*W, C) -> (B, C, H, W)
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|