Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git (synced 2026-02-22 15:53:58 +00:00)

Commit: more attention types
@@ -3,7 +3,6 @@ import argparse
parser = argparse.ArgumentParser()

parser.add_argument("--gpu-device-id", type=int, default=None, metavar="DEVICE_ID")
parser.add_argument("--disable-attention-upcast", action="store_true")

fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--all-in-fp32", action="store_true")
@@ -28,7 +27,17 @@ fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true")
fpte_group.add_argument("--clip-in-fp16", action="store_true")
fpte_group.add_argument("--clip-in-fp32", action="store_true")

parser.add_argument("--xformers", action="store_true")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--attention-split", action="store_true")
attn_group.add_argument("--attention-quad", action="store_true")
attn_group.add_argument("--attention-pytorch", action="store_true")

upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true")
upcast.add_argument("--disable-attention-upcast", action="store_true")

parser.add_argument("--disable-xformers", action="store_true")

parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1)
parser.add_argument("--disable-ipex-hijack", action="store_true")
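
Note: the attention backend is now chosen by mutually exclusive CLI flags, with a separate group controlling attention upcasting. A minimal, self-contained sketch of how these argparse groups behave (the flag names come from the hunk above; the example command lines are illustrative, not part of the commit):

import argparse

parser = argparse.ArgumentParser()

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--attention-split", action="store_true")
attn_group.add_argument("--attention-quad", action="store_true")
attn_group.add_argument("--attention-pytorch", action="store_true")

upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true")
upcast.add_argument("--disable-attention-upcast", action="store_true")

# Picking one backend plus one upcast policy is fine:
args = parser.parse_args(["--attention-quad", "--disable-attention-upcast"])
assert args.attention_quad and args.disable_attention_upcast

# Two backends at once are rejected by argparse (it exits with an error):
# parser.parse_args(["--attention-split", "--attention-pytorch"])
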
@@ -1,33 +1,89 @@
import math
import torch
import einops

from backend.args import args
from backend import memory_management
from backend.misc.sub_quadratic_attention import efficient_dot_product_attention

if args.xformers:

BROKEN_XFORMERS = False
if memory_management.xformers_enabled():
    import xformers
    import xformers.ops

    try:
        x_vers = xformers.__version__
        BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
    except:
        pass

def attention_xformers(q, k, v, heads, mask=None):
    b, _, dim_head = q.shape
    dim_head //= heads

    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )
FORCE_UPCAST_ATTENTION_DTYPE = memory_management.force_upcast_attention_dtype()

    if mask is not None:
        pad = 8 - q.shape[1] % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
        mask_out[:, :, :mask.shape[-1]] = mask
        mask = mask_out[:, :, :mask.shape[-1]]

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
def get_attn_precision(attn_precision):
    if args.disable_attention_upcast:
        return None
    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
        return FORCE_UPCAST_ATTENTION_DTYPE
    return attn_precision


def exists(val):
    return val is not None


def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    h = heads
    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    if attn_precision == torch.float32:
        sim = torch.einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if exists(mask):
        if mask.dtype == torch.bool:
            mask = einops.rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = einops.repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)
        else:
            if len(mask.shape) == 2:
                bs = 1
            else:
                bs = mask.shape[0]
            mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
            sim.add_(mask)

    sim = sim.softmax(dim=-1)
    out = torch.einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
@@ -37,56 +93,372 @@ def attention_xformers(q, k, v, heads, mask=None):
    return out
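
Note: all of these kernels share one layout convention: q/k/v arrive as [batch, tokens, heads * dim_head] (or, with skip_reshape=True, already split per head), get flattened to [batch * heads, tokens, dim_head] for the matmuls, and are merged back at the end. A small sketch of that round trip with made-up sizes (illustrative only, not code from the commit):

import torch

b, tokens, heads, dim_head = 2, 77, 8, 64
x = torch.randn(b, tokens, heads * dim_head)

# split heads: [b, tokens, heads*dim_head] -> [b*heads, tokens, dim_head]
split = (
    x.unsqueeze(3)
    .reshape(b, -1, heads, dim_head)
    .permute(0, 2, 1, 3)
    .reshape(b * heads, -1, dim_head)
)

# merge heads back: [b*heads, tokens, dim_head] -> [b, tokens, heads*dim_head]
merged = (
    split.unsqueeze(0)
    .reshape(b, heads, -1, dim_head)
    .permute(0, 2, 1, 3)
    .reshape(b, -1, heads * dim_head)
)

assert torch.equal(x, merged)  # the split/merge pair is lossless
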
def attention_pytorch(q, k, v, heads, mask=None):
    b, _, dim_head = q.shape
    dim_head //= heads
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

    q, k, v = map(
        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
        (q, k, v),
    if skip_reshape:
        b, _, _, dim_head = query.shape
    else:
        b, _, dim_head = query.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    if skip_reshape:
        query = query.reshape(b * heads, -1, dim_head)
        value = value.reshape(b * heads, -1, dim_head)
        key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
    else:
        query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

    dtype = query.dtype
    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
    if upcast_attention:
        bytes_per_token = torch.finfo(torch.float32).bits // 8
    else:
        bytes_per_token = torch.finfo(query.dtype).bits // 8
    batch_x_heads, q_tokens, _ = query.shape
    _, _, k_tokens = key.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    mem_free_total, mem_free_torch = memory_management.get_free_memory(query.device, True)

    kv_chunk_size_min = None
    kv_chunk_size = None
    query_chunk_size = None

    for x in [4096, 2048, 1024, 512, 256]:
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens
            query_chunk_size = x
            break

    if query_chunk_size is None:
        query_chunk_size = 512

    if mask is not None:
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    hidden_states = efficient_dot_product_attention(
        query,
        key,
        value,
        query_chunk_size=query_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=kv_chunk_size_min,
        use_checkpoint=False,
        upcast_attention=upcast_attention,
        mask=mask,
    )

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    hidden_states = hidden_states.to(dtype)

    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
    return hidden_states
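
Note: the query_chunk_size loop above is a free-memory heuristic. For each candidate chunk size x it estimates how many key tokens fit alongside a query chunk of x tokens (batch_x_heads * bytes_per_token * x bytes per key token, with a 4x safety factor) and takes the largest x that still covers all k_tokens; if none does, it falls back to chunks of 512 and lets kv chunking do the rest. A worked example with invented numbers (not from the commit):

# Illustrative numbers: 2 images x 8 heads, fp16, 4096 key tokens, 2 GiB free.
batch_x_heads = 16
bytes_per_token = 2          # fp16
k_tokens = 4096
mem_free_total = 2 * 1024 ** 3

query_chunk_size = None
kv_chunk_size = None
for x in [4096, 2048, 1024, 512, 256]:
    count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
    if count >= k_tokens:
        kv_chunk_size = k_tokens   # the whole key fits next to a query chunk of size x
        query_chunk_size = x
        break

if query_chunk_size is None:
    query_chunk_size = 512         # fall back and rely on kv chunking instead

print(query_chunk_size, kv_chunk_size)
# x=4096: count = 2147483648 / (16 * 2 * 4096 * 4) = 4096 >= 4096, so 4096 is chosen.
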
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    h = heads
    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    mem_free_total = memory_management.get_free_memory(q.device)

    if attn_precision == torch.float32:
        element_size = 4
        upcast = True
    else:
        element_size = q.element_size()
        upcast = False

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    modifier = 3
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

    if mask is not None:
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
    first_op_done = False
    cleared_cache = False
    while True:
        try:
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                if upcast:
                    with torch.autocast(enabled=False, device_type='cuda'):
                        s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                else:
                    s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

                if mask is not None:
                    if len(mask.shape) == 2:
                        s1 += mask[i:end]
                    else:
                        s1 += mask[:, i:end]

                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                first_op_done = True

                r1[:, i:end] = torch.einsum('b i j, b j d -> b i d', s2, v)
                del s2
            break
        except memory_management.OOM_EXCEPTION as e:
            if first_op_done == False:
                memory_management.soft_empty_cache(True)
                if cleared_cache == False:
                    cleared_cache = True
                    print("out of memory error, emptying cache and trying again")
                    continue
                steps *= 2
                if steps > 64:
                    raise e
                print("out of memory error, increasing steps and trying again {}".format(steps))
            else:
                raise e

    del q, k, v

    r1 = (
        r1.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return r1
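
Note: attention_split sizes its slices from the same free-memory estimate. It assumes the q@k score tensor needs roughly 3x its own size in free memory, rounds the required/free ratio up to the next power of two to get `steps`, and keeps doubling `steps` on OOM until it gives up past 64. A small arithmetic sketch with invented numbers (illustrative only):

import math

# Invented example: the score tensor needs 6 GiB, but only 1 GiB is free.
gb = 1024 ** 3
tensor_size = 2 * gb            # q.shape[0] * q.shape[1] * k.shape[1] * element_size
mem_required = tensor_size * 3  # modifier = 3
mem_free_total = 1 * gb

steps = 1
if mem_required > mem_free_total:
    # ratio 6 -> ceil(log2(6)) = 3 -> 2**3 = 8 slices over the query tokens
    steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

print(steps)  # 8
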
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    disabled_xformers = False

    if BROKEN_XFORMERS:
        if b * heads > 65535:
            disabled_xformers = True

    if not disabled_xformers:
        if torch.jit.is_tracing() or torch.jit.is_scripting():
            disabled_xformers = True

    if disabled_xformers:
        return attention_pytorch(q, k, v, heads, mask)

    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.reshape(b, -1, heads, dim_head),
            (q, k, v),
        )

    if mask is not None:
        pad = 8 - q.shape[1] % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
        mask_out[:, :, :mask.shape[-1]] = mask
        mask = mask_out[:, :, :mask.shape[-1]]

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    if skip_reshape:
        out = (
            out.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, heads * dim_head)
        )
    else:
        out = (
            out.reshape(b, -1, heads * dim_head)
        )

    return out


def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
            (q, k, v),
        )

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    out = (
        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    )
    return out
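
Note: attention_pytorch hands the reshaped tensors to torch.nn.functional.scaled_dot_product_attention, which expects [batch, heads, tokens, dim_head] and applies the 1/sqrt(dim_head) scaling itself. A minimal standalone call with made-up shapes (illustrative, not code from the commit):

import torch
import torch.nn.functional as F

b, heads, tokens, dim_head = 2, 8, 77, 64
q = torch.randn(b, tokens, heads * dim_head)
k = torch.randn(b, tokens, heads * dim_head)
v = torch.randn(b, tokens, heads * dim_head)

# same reshape as attention_pytorch: [b, tokens, heads*dim_head] -> [b, heads, tokens, dim_head]
q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))

out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
print(out.shape)  # torch.Size([2, 77, 512])
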
def attention_xformers_single_head_spatial(q, k, v):
def slice_attention_single_head_spatial(q, k, v):
    r1 = torch.zeros_like(k, device=q.device)
    scale = (int(q.shape[-1]) ** (-0.5))

    mem_free_total = memory_management.get_free_memory(q.device)

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    while True:
        try:
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                s1 = torch.bmm(q[:, i:end], k) * scale

                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0, 2, 1)
                del s1

                r1[:, :, i:end] = torch.bmm(v, s2)
                del s2
            break
        except memory_management.OOM_EXCEPTION as e:
            memory_management.soft_empty_cache(True)
            steps *= 2
            if steps > 128:
                raise e
            print("out of memory error, increasing steps and trying again {}".format(steps))

    return r1


def normal_attention_single_head_spatial(q, k, v):
    # compute attention
    b, c, h, w = q.shape

    q = q.reshape(b, c, h * w)
    q = q.permute(0, 2, 1)  # b,hw,c
    k = k.reshape(b, c, h * w)  # b,c,hw
    v = v.reshape(b, c, h * w)

    r1 = slice_attention_single_head_spatial(q, k, v)
    h_ = r1.reshape(b, c, h, w)
    del r1
    return h_


def xformers_attention_single_head_spatial(q, k, v):
    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(
        lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
        (q, k, v),
    )
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    out = out.transpose(1, 2).reshape(B, C, H, W)

    try:
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
        out = out.transpose(1, 2).reshape(B, C, H, W)
    except NotImplementedError as e:
        out = slice_attention_single_head_spatial(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2),
                                                  v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
    return out


def attention_pytorch_single_head_spatial(q, k, v):
def pytorch_attention_single_head_spatial(q, k, v):
    # compute attention
    B, C, H, W = q.shape
    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    out = out.transpose(2, 3).reshape(B, C, H, W)

    try:
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        out = out.transpose(2, 3).reshape(B, C, H, W)
    except memory_management.OOM_EXCEPTION as e:
        print("scaled_dot_product_attention OOMed: switched to slice attention")
        out = slice_attention_single_head_spatial(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2),
                                                  v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
    return out


attention_function = attention_pytorch
attention_function_single_head_spatial = attention_pytorch_single_head_spatial

if args.xformers:
if memory_management.xformers_enabled():
    print("Using xformers cross attention")
    attention_function = attention_xformers
    attention_function_single_head_spatial = attention_xformers_single_head_spatial
else:
elif memory_management.pytorch_attention_enabled():
    print("Using pytorch cross attention")
    attention_function = attention_pytorch
    attention_function_single_head_spatial = attention_pytorch_single_head_spatial
elif args.attention_split:
    print("Using split optimization for cross attention")
    attention_function = attention_split
else:
    print("Using sub quadratic optimization for cross attention")
    attention_function = attention_sub_quad

if memory_management.xformers_enabled_vae():
    print("Using xformers attention for VAE")
    attention_function_single_head_spatial = xformers_attention_single_head_spatial
elif memory_management.pytorch_attention_enabled():
    print("Using pytorch attention for VAE")
    attention_function_single_head_spatial = pytorch_attention_single_head_spatial
else:
    print("Using split attention for VAE")
    attention_function_single_head_spatial = normal_attention_single_head_spatial


class AttentionProcessorForge:
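
Note: taken together, the module-level dispatch above means the CLI flags from the argument-parser hunk map onto the new kernels roughly as follows. This is a summary sketch, not code from the commit; the `from backend import attention` path is an assumption based on the other backend imports in this file.

# Cross-attention selection order (first match wins):
#   xformers available                 -> attention_xformers
#   pytorch attention enabled          -> attention_pytorch
#   --attention-split                  -> attention_split
#   otherwise (--attention-quad / none) -> attention_sub_quad
#
# VAE (single-head spatial) selection mirrors it:
#   xformers for VAE   -> xformers_attention_single_head_spatial
#   pytorch attention  -> pytorch_attention_single_head_spatial
#   otherwise          -> normal_attention_single_head_spatial (sliced)

from backend import attention  # assumed module path for this file

def run_cross_attention(q, k, v, heads, mask=None):
    # All backends share the same call signature, so callers can stay agnostic.
    return attention.attention_function(q, k, v, heads, mask=mask)
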
backend/misc/sub_quadratic_attention.py (new file, 273 lines)
@@ -0,0 +1,273 @@
# original source:
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
# MIT
# credit:
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# implementation of:
# "Self-attention Does Not Need O(n2) Memory":
# https://arxiv.org/abs/2112.05682v2

from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math

try:
    from typing import Optional, NamedTuple, List, Protocol
except ImportError:
    from typing import Optional, NamedTuple, List
    from typing_extensions import Protocol

from torch import Tensor
from typing import List

from ldm_patched.modules import model_management


def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
    return x[slicing]


class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor


class SummarizeChunk(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...


class ComputeQueryChunkAttn(Protocol):
    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> Tensor: ...


def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> AttnChunk:
    if upcast_attention:
        with torch.autocast(enabled=False, device_type = 'cuda'):
            query = query.float()
            key_t = key_t.float()
            attn_weights = torch.baddbmm(
                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
                query,
                key_t,
                alpha=scale,
                beta=0,
            )
    else:
        attn_weights = torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    attn_weights -= max_score
    if mask is not None:
        attn_weights += mask
    torch.exp(attn_weights, out=attn_weights)
    exp_weights = attn_weights.to(value.dtype)
    exp_values = torch.bmm(exp_weights, value)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)


def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
    mask,
) -> Tensor:
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
        key_chunk = dynamic_slice(
            key_t,
            (0, 0, chunk_idx),
            (batch_x_heads, k_channels_per_head, kv_chunk_size)
        )
        value_chunk = dynamic_slice(
            value,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, v_channels_per_head)
        )
        if mask is not None:
            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]

        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

    chunks: List[AttnChunk] = [
        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
    ]
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights


# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
    mask,
) -> Tensor:
    if upcast_attention:
        with torch.autocast(enabled=False, device_type = 'cuda'):
            query = query.float()
            key_t = key_t.float()
            attn_scores = torch.baddbmm(
                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
                query,
                key_t,
                alpha=scale,
                beta=0,
            )
    else:
        attn_scores = torch.baddbmm(
            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
            query,
            key_t,
            alpha=scale,
            beta=0,
        )

    if mask is not None:
        attn_scores += mask
    try:
        attn_probs = attn_scores.softmax(dim=-1)
        del attn_scores
    except model_management.OOM_EXCEPTION:
        print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
        torch.exp(attn_scores, out=attn_scores)
        summed = torch.sum(attn_scores, dim=-1, keepdim=True)
        attn_scores /= summed
        attn_probs = attn_scores

    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
    return hidden_states_slice


class ScannedChunk(NamedTuple):
    chunk_idx: int
    attn_chunk: AttnChunk


def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size=1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint=True,
    upcast_attention=False,
    mask = None,
):
    """Computes efficient dot-product attention given query, transposed key, and value.
    This is efficient version of attention presented in
    https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
    Args:
      query: queries for calculating attention with shape of
        `[batch * num_heads, tokens, channels_per_head]`.
      key_t: keys for calculating attention with shape of
        `[batch * num_heads, channels_per_head, tokens]`.
      value: values to be used in attention with shape of
        `[batch * num_heads, tokens, channels_per_head]`.
      query_chunk_size: int: query chunks size
      kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
      kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
      use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
    Returns:
      Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
    """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    if mask is not None and len(mask.shape) == 2:
        mask = mask.unsqueeze(0)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
            query,
            (0, chunk_idx, 0),
            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
        )

    def get_mask_chunk(chunk_idx: int) -> Tensor:
        if mask is None:
            return None
        chunk = min(query_chunk_size, q_tokens)
        return mask[:,chunk_idx:chunk_idx + chunk]

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
        _get_attention_scores_no_kv_chunking,
        scale=scale,
        upcast_attention=upcast_attention
    ) if k_tokens <= kv_chunk_size else (
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )
    )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
            mask=mask,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key_t=key_t,
            value=value,
            mask=get_mask_chunk(i * query_chunk_size)
        ) for i in range(math.ceil(q_tokens / query_chunk_size))
    ], dim=1)
    return res
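
Note: per the docstring, this function takes the key already transposed to [batch * num_heads, channels_per_head, tokens] and returns [batch * num_heads, query_tokens, channels_per_head]. A minimal usage sketch with invented sizes (illustrative only; use_checkpoint=False matches how attention_sub_quad calls it for inference, and the import path matches the one used in backend/attention.py above):

import torch
from backend.misc.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, q_tokens, k_tokens, dim_head = 16, 4096, 4096, 64
query = torch.randn(batch_x_heads, q_tokens, dim_head)
key_t = torch.randn(batch_x_heads, dim_head, k_tokens)   # note: key is pre-transposed
value = torch.randn(batch_x_heads, k_tokens, dim_head)

out = efficient_dot_product_attention(
    query,
    key_t,
    value,
    query_chunk_size=1024,
    kv_chunk_size=None,          # defaults to sqrt(k_tokens) per chunk
    kv_chunk_size_min=None,
    use_checkpoint=False,        # inference
    upcast_attention=False,
    mask=None,
)
print(out.shape)  # torch.Size([16, 4096, 64])
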
@@ -47,19 +47,19 @@ def initialize_forge():
    from modules_forge.cuda_malloc import try_cuda_malloc
    try_cuda_malloc()

    import ldm_patched.modules.model_management as model_management
    from backend import memory_management
    import torch

    monitor_module_moving()

    device = model_management.get_torch_device()
    device = memory_management.get_torch_device()
    torch.zeros((1, 1)).to(device, torch.float32)
    model_management.soft_empty_cache()
    memory_management.soft_empty_cache()

    import modules_forge.patch_basic
    modules_forge.patch_basic.patch_all_basics()

    from modules_forge import stream
    from backend import stream
    print('CUDA Stream Activated: ', stream.using_stream)

    from modules_forge.shared import diffusers_dir
@@ -1,6 +1,6 @@
import torch
import os
import safetensors
import safetensors.torch


def build_loaded(module, loader_name):