From 68b672493a28d065cf5c7fde5a2a7849bf6e1072 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Mon, 29 Jul 2024 11:46:16 -0600
Subject: [PATCH] implement attention for new backend

---
 backend/args.py      |  7 +----
 backend/attention.py | 85 ++++++++++++++++++++++++++++++++++++++++++++
 backend/stream.py    |  2 +-
 3 files changed, 87 insertions(+), 7 deletions(-)
 create mode 100644 backend/attention.py

diff --git a/backend/args.py b/backend/args.py
index 5368a25b..1888613b 100644
--- a/backend/args.py
+++ b/backend/args.py
@@ -28,12 +28,7 @@ fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true")
 fpte_group.add_argument("--clip-in-fp16", action="store_true")
 fpte_group.add_argument("--clip-in-fp32", action="store_true")
 
-attn_group = parser.add_mutually_exclusive_group()
-attn_group.add_argument("--attention-split", action="store_true")
-attn_group.add_argument("--attention-quad", action="store_true")
-attn_group.add_argument("--attention-pytorch", action="store_true")
-
-parser.add_argument("--disable-xformers", action="store_true")
+parser.add_argument("--xformers", action="store_true")
 
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1)
 parser.add_argument("--disable-ipex-hijack", action="store_true")
diff --git a/backend/attention.py b/backend/attention.py
new file mode 100644
index 00000000..f6cc1ec1
--- /dev/null
+++ b/backend/attention.py
@@ -0,0 +1,85 @@
+import torch
+
+from backend.args import args
+from einops import rearrange, repeat
+from typing import Optional
+
+if args.xformers:
+    import xformers
+    import xformers.ops
+
+
+def attention_xformers(q, k, v, heads, mask=None):
+    """Multi-head attention through xformers' memory-efficient kernel.
+
+    q, k and v are (batch, tokens, heads * dim_head) tensors; the result has
+    the same layout as q. `mask`, when given, is passed to xformers as the
+    attention bias and must broadcast to (batch * heads, q_tokens, kv_tokens).
+    """
+    b, _, dim_head = q.shape
+    dim_head //= heads
+
+    # Fold the heads into the batch axis: (b, n, h * d) -> (b * h, n, d).
+    q, k, v = map(
+        lambda t: t.unsqueeze(3)
+        .reshape(b, -1, heads, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b * heads, -1, dim_head)
+        .contiguous(),
+        (q, k, v),
+    )
+
+    if mask is not None:
+        # xformers wants the bias storage padded to a multiple of 8 along the
+        # key axis, so allocate a padded buffer and pass a view of the real
+        # width. The width follows the mask (key length), not the query
+        # length, so cross-attention with more keys than queries works.
+        kv_len = mask.shape[-1]
+        pad = 8 - kv_len % 8
+        mask_out = torch.empty([q.shape[0], q.shape[1], kv_len + pad], dtype=q.dtype, device=q.device)
+        mask_out[:, :, :kv_len] = mask
+        mask = mask_out[:, :, :kv_len]
+
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+
+    # Unfold the heads again: (b * h, n, d) -> (b, n, h * d).
+    out = (
+        out.unsqueeze(0)
+        .reshape(b, heads, -1, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b, -1, heads * dim_head)
+    )
+    return out
+
+
+def attention_pytorch(q, k, v, heads, mask=None):
+    """Multi-head attention through torch's scaled_dot_product_attention.
+
+    Takes and returns the same (batch, tokens, heads * dim_head) layout as
+    attention_xformers; `mask` is forwarded unchanged as `attn_mask`.
+    """
+    b, _, dim_head = q.shape
+    dim_head //= heads
+    # Split the heads out: (b, n, h * d) -> (b, h, n, d).
+    q, k, v = map(
+        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+        (q, k, v),
+    )
+
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
+    out = (
+        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+    )
+    return out
+
+
+# Pick the attention backend once at import time, based on --xformers.
+attention_function = attention_pytorch
+
+if args.xformers:
+    print("Using xformers cross attention")
+    attention_function = attention_xformers
+else:
+    print("Using pytorch cross attention")
+    attention_function = attention_pytorch
diff --git a/backend/stream.py b/backend/stream.py
index 3972d0e4..e051d442 100644
--- a/backend/stream.py
+++ b/backend/stream.py
@@ -1,5 +1,5 @@
 import torch
-from backend import args
+from backend.args import args
 
 
 def stream_context():