mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 08:59:51 +00:00
implement attention for new backend
@@ -28,12 +28,7 @@ fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true")
fpte_group.add_argument("--clip-in-fp16", action="store_true")
fpte_group.add_argument("--clip-in-fp32", action="store_true")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--attention-split", action="store_true")
attn_group.add_argument("--attention-quad", action="store_true")
attn_group.add_argument("--attention-pytorch", action="store_true")

parser.add_argument("--disable-xformers", action="store_true")
parser.add_argument("--xformers", action="store_true")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1)
parser.add_argument("--disable-ipex-hijack", action="store_true")
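The hunk above (apparently the argument-parser module that the new attention code imports from) groups the attention flags into a mutually exclusive argparse group, so a user can force at most one of the split, quad, or PyTorch implementations. A minimal standalone sketch of that behavior, using plain argparse rather than Forge's actual parser setup:

    import argparse

    parser = argparse.ArgumentParser()
    attn_group = parser.add_mutually_exclusive_group()
    attn_group.add_argument("--attention-split", action="store_true")
    attn_group.add_argument("--attention-quad", action="store_true")
    attn_group.add_argument("--attention-pytorch", action="store_true")

    # One flag parses normally...
    opts = parser.parse_args(["--attention-pytorch"])
    assert opts.attention_pytorch and not opts.attention_split

    # ...while two at once are rejected by argparse itself:
    # parser.parse_args(["--attention-split", "--attention-quad"])
    # error: argument --attention-quad: not allowed with argument --attention-split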
backend/attention.py (new file, 65 lines)

@@ -0,0 +1,65 @@
import torch

from backend.args import args
from einops import rearrange, repeat
from typing import Optional

if args.xformers:
    import xformers
    import xformers.ops

def attention_xformers(q, k, v, heads, mask=None):
    b, _, dim_head = q.shape
    dim_head //= heads

    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    if mask is not None:
        pad = 8 - q.shape[1] % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
        mask_out[:, :, :mask.shape[-1]] = mask
        mask = mask_out[:, :, :mask.shape[-1]]

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out

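The map() call folds the head dimension into the batch dimension, since xformers' memory_efficient_attention operates on (batch * heads, seq_len, dim_head) tensors; the mask branch pads the bias width out to a multiple of 8, apparently to satisfy xformers' alignment expectations for attn_bias, then slices back to the original width. A standalone round-trip check of the reshape, with illustrative shapes not taken from the commit:

    import torch

    b, seq, heads, dim_head = 2, 77, 8, 64
    x = torch.randn(b, seq, heads * dim_head)

    # fold heads into batch: (b, seq, heads * dim_head) -> (b * heads, seq, dim_head)
    folded = (
        x.reshape(b, seq, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, seq, dim_head)
    )

    # unfold back: (b * heads, seq, dim_head) -> (b, seq, heads * dim_head)
    restored = (
        folded.reshape(b, heads, seq, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, seq, heads * dim_head)
    )

    assert torch.equal(x, restored)  # the round trip is lossless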
def attention_pytorch(q, k, v, heads, mask=None):
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(
        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
        (q, k, v),
    )

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)

    out = (
        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    )
    return out

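attention_pytorch leans on torch.nn.functional.scaled_dot_product_attention, which fuses the softmax(QK^T / sqrt(d)) V computation and picks an efficient kernel internally. For intuition, a sketch comparing the fused call against the same math written out by hand (random shapes, not from the commit):

    import torch
    import torch.nn.functional as F

    b, heads, seq, dim_head = 2, 8, 77, 64
    q, k, v = (torch.randn(b, heads, seq, dim_head) for _ in range(3))

    # the math scaled_dot_product_attention implements, written out
    scores = q @ k.transpose(-2, -1) / dim_head**0.5
    manual = scores.softmax(dim=-1) @ v

    fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    assert torch.allclose(manual, fused, atol=1e-4)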
attention_function = attention_pytorch

if args.xformers:
    print("Using xformers cross attention")
    attention_function = attention_xformers
else:
    print("Using pytorch cross attention")
    attention_function = attention_pytorch
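Downstream code can stay agnostic about which kernel is active by calling attention_function. A hypothetical call site follows; the packed (batch, seq_len, heads * dim_head) input layout is inferred from the reshapes above, not documented in the commit:

    import torch
    from backend.attention import attention_function

    heads, dim_head = 8, 64
    q = torch.randn(2, 77, heads * dim_head)
    k = torch.randn(2, 77, heads * dim_head)
    v = torch.randn(2, 77, heads * dim_head)

    out = attention_function(q, k, v, heads)
    print(out.shape)  # torch.Size([2, 77, 512])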
@@ -1,5 +1,5 @@
 import torch
-from backend import args
+from backend.args import args


 def stream_context():
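This one-line change matters because "from backend import args" binds the backend.args module object, while the caller (presumably a stream-handling module, given stream_context) needs the parsed Namespace defined inside it. A sketch of the failure mode, assuming backend/args.py exposes a module-level args = parser.parse_args():

    from backend import args          # binds the module object backend.args
    # args.xformers                   # AttributeError: module has no attribute 'xformers'
    # args.args.xformers              # works, but is not what callers write

    from backend.args import args     # rebinds the name to the parsed Namespace
    # args.xformers                   # -> True or False, as intended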