Files
stable-diffusion-webui-forge/backend/attention.py
layerdiffusion 3bc9402de9 ling
2024-07-29 11:48:55 -06:00

64 lines
1.6 KiB
Python

import torch
from backend.args import args
if args.xformers:
import xformers
import xformers.ops
def attention_xformers(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
if mask is not None:
pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return out
def attention_pytorch(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
attention_function = attention_pytorch
if args.xformers:
print("Using xformers cross attention")
attention_function = attention_xformers
else:
print("Using pytorch cross attention")
attention_function = attention_pytorch