From 3dd084d55ba633d3b2f133778efd28ccd8428745 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Fri, 2 Aug 2024 03:32:19 -0700
Subject: [PATCH] more attention types

---
 backend/args.py                         |  13 +-
 backend/attention.py                    | 448 ++++++++++++++++++++++--
 backend/misc/sub_quadratic_attention.py | 273 +++++++++++++++
 modules_forge/initialization.py         |   8 +-
 modules_forge/patch_basic.py            |   2 +-
 5 files changed, 699 insertions(+), 45 deletions(-)
 create mode 100644 backend/misc/sub_quadratic_attention.py

diff --git a/backend/args.py b/backend/args.py
index 1888613b..8c8e1cbf 100644
--- a/backend/args.py
+++ b/backend/args.py
@@ -3,7 +3,6 @@ import argparse
 parser = argparse.ArgumentParser()
 
 parser.add_argument("--gpu-device-id", type=int, default=None, metavar="DEVICE_ID")
-parser.add_argument("--disable-attention-upcast", action="store_true")
 
 fp_group = parser.add_mutually_exclusive_group()
 fp_group.add_argument("--all-in-fp32", action="store_true")
@@ -28,7 +27,17 @@ fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true")
 fpte_group.add_argument("--clip-in-fp16", action="store_true")
 fpte_group.add_argument("--clip-in-fp32", action="store_true")
 
-parser.add_argument("--xformers", action="store_true")
+attn_group = parser.add_mutually_exclusive_group()
+attn_group.add_argument("--attention-split", action="store_true")
+attn_group.add_argument("--attention-quad", action="store_true")
+attn_group.add_argument("--attention-pytorch", action="store_true")
+
+upcast = parser.add_mutually_exclusive_group()
+upcast.add_argument("--force-upcast-attention", action="store_true")
+upcast.add_argument("--disable-attention-upcast", action="store_true")
+
+parser.add_argument("--disable-xformers", action="store_true")
+
 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1)
 parser.add_argument("--disable-ipex-hijack", action="store_true")
 
diff --git a/backend/attention.py b/backend/attention.py
index 776da59f..b72dc07e 100644
--- a/backend/attention.py
+++ b/backend/attention.py
@@ -1,33 +1,89 @@
+import math
 import torch
+import einops
 
 from backend.args import args
+from backend import memory_management
+from backend.misc.sub_quadratic_attention import efficient_dot_product_attention
 
-if args.xformers:
+
+BROKEN_XFORMERS = False
+if memory_management.xformers_enabled():
     import xformers
     import xformers.ops
 
+    try:
+        x_vers = xformers.__version__
+        BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
+    except:
+        pass
 
-def attention_xformers(q, k, v, heads, mask=None):
-    b, _, dim_head = q.shape
-    dim_head //= heads
-
-    q, k, v = map(
-        lambda t: t.unsqueeze(3)
-        .reshape(b, -1, heads, dim_head)
-        .permute(0, 2, 1, 3)
-        .reshape(b * heads, -1, dim_head)
-        .contiguous(),
-        (q, k, v),
-    )
+FORCE_UPCAST_ATTENTION_DTYPE = memory_management.force_upcast_attention_dtype()
 
-    if mask is not None:
-        pad = 8 - q.shape[1] % 8
-        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
-        mask_out[:, :, :mask.shape[-1]] = mask
-        mask = mask_out[:, :, :mask.shape[-1]]
 
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+def get_attn_precision(attn_precision):
+    if args.disable_attention_upcast:
+        return None
+    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
+        return FORCE_UPCAST_ATTENTION_DTYPE
+    return attn_precision
+
+def exists(val):
+    return val is not None
+
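+# Precision resolution, sketched: get_attn_precision lets the
+# --disable-attention-upcast kill-switch win first (returns None), then any
+# dtype memory_management forces for the device (presumably fed by
+# --force-upcast-attention and known-quirky GPUs), then the caller's request:
+#     get_attn_precision(torch.float32)  # -> torch.float32 in the default case
+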
+
+def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    attn_precision = get_attn_precision(attn_precision)
+
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+
+    scale = dim_head ** -0.5
+
+    h = heads
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, -1, heads, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * heads, -1, dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+    if attn_precision == torch.float32:
+        sim = torch.einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
+    else:
+        sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale
+
+    del q, k
+
+    if exists(mask):
+        if mask.dtype == torch.bool:
+            mask = einops.rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = einops.repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+        else:
+            if len(mask.shape) == 2:
+                bs = 1
+            else:
+                bs = mask.shape[0]
+            mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+            sim.add_(mask)
+
+    sim = sim.softmax(dim=-1)
+    out = torch.einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
     out = (
         out.unsqueeze(0)
         .reshape(b, heads, -1, dim_head)
@@ -37,56 +93,372 @@ def attention_xformers(q, k, v, heads, mask=None):
     return out
 
 
-def attention_pytorch(q, k, v, heads, mask=None):
-    b, _, dim_head = q.shape
-    dim_head //= heads
+def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
+    attn_precision = get_attn_precision(attn_precision)
 
-    q, k, v = map(
-        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
-        (q, k, v),
+    if skip_reshape:
+        b, _, _, dim_head = query.shape
+    else:
+        b, _, dim_head = query.shape
+        dim_head //= heads
+
+    scale = dim_head ** -0.5
+
+    if skip_reshape:
+        query = query.reshape(b * heads, -1, dim_head)
+        value = value.reshape(b * heads, -1, dim_head)
+        key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
+    else:
+        query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+        value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+        key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
+
+    dtype = query.dtype
+    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
+    if upcast_attention:
+        bytes_per_token = torch.finfo(torch.float32).bits // 8
+    else:
+        bytes_per_token = torch.finfo(query.dtype).bits // 8
+    batch_x_heads, q_tokens, _ = query.shape
+    _, _, k_tokens = key.shape
+    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+    mem_free_total, mem_free_torch = memory_management.get_free_memory(query.device, True)
+
+    kv_chunk_size_min = None
+    kv_chunk_size = None
+    query_chunk_size = None
+
+    for x in [4096, 2048, 1024, 512, 256]:
+        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
+        if count >= k_tokens:
+            kv_chunk_size = k_tokens
+            query_chunk_size = x
+            break
+
+    if query_chunk_size is None:
+        query_chunk_size = 512
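+
+    # Chunk-size heuristic, annotated: the loop above picks the largest query
+    # chunk x such that one (batch*heads, x, k_tokens) score slice, with a ~4x
+    # working-set factor, fits in currently free memory; when it fits, kv
+    # chunking is disabled (kv_chunk_size = k_tokens) and only queries are
+    # chunked. If nothing fits, fall back to 512 queries per chunk and let
+    # efficient_dot_product_attention also chunk the kv axis (~sqrt(k_tokens)).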
+
+    if mask is not None:
+        if len(mask.shape) == 2:
+            bs = 1
+        else:
+            bs = mask.shape[0]
+        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
+    hidden_states = efficient_dot_product_attention(
+        query,
+        key,
+        value,
+        query_chunk_size=query_chunk_size,
+        kv_chunk_size=kv_chunk_size,
+        kv_chunk_size_min=kv_chunk_size_min,
+        use_checkpoint=False,
+        upcast_attention=upcast_attention,
+        mask=mask,
     )
-    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+    hidden_states = hidden_states.to(dtype)
+    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
+    return hidden_states
+
+
+def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    attn_precision = get_attn_precision(attn_precision)
+
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+
+    scale = dim_head ** -0.5
+
+    h = heads
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, -1, heads, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * heads, -1, dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+    mem_free_total = memory_management.get_free_memory(q.device)
+
+    if attn_precision == torch.float32:
+        element_size = 4
+        upcast = True
+    else:
+        element_size = q.element_size()
+        upcast = False
+
+    gb = 1024 ** 3
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
+    modifier = 3
+    mem_required = tensor_size * modifier
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
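+
+    # Memory math, annotated: a (B*H, N, M) score tensor costs
+    # B*H * N * M * element_size bytes, and modifier = 3 approximates score,
+    # softmax and output temporaries together. steps doubles until one query
+    # slice of N // steps rows fits, so each retry roughly halves the peak.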
+
+    if steps > 64:
+        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+
+    if mask is not None:
+        if len(mask.shape) == 2:
+            bs = 1
+        else:
+            bs = mask.shape[0]
+        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
+
+    # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
+    first_op_done = False
+    cleared_cache = False
+    while True:
+        try:
+            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                if upcast:
+                    with torch.autocast(enabled=False, device_type='cuda'):
+                        s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
+                else:
+                    s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
+
+                if mask is not None:
+                    if len(mask.shape) == 2:
+                        s1 += mask[i:end]
+                    else:
+                        s1 += mask[:, i:end]
+
+                s2 = s1.softmax(dim=-1).to(v.dtype)
+                del s1
+                first_op_done = True
+
+                r1[:, i:end] = torch.einsum('b i j, b j d -> b i d', s2, v)
+                del s2
+            break
+        except memory_management.OOM_EXCEPTION as e:
+            if first_op_done == False:
+                memory_management.soft_empty_cache(True)
+                if cleared_cache == False:
+                    cleared_cache = True
+                    print("out of memory error, emptying cache and trying again")
+                    continue
+                steps *= 2
+                if steps > 64:
+                    raise e
+                print("out of memory error, increasing steps and trying again {}".format(steps))
+            else:
+                raise e
+
+    del q, k, v
+
+    r1 = (
+        r1.unsqueeze(0)
+        .reshape(b, heads, -1, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b, -1, heads * dim_head)
+    )
+    return r1
+
+
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+
+    disabled_xformers = False
+
+    if BROKEN_XFORMERS:
+        if b * heads > 65535:
+            disabled_xformers = True
+
+    if not disabled_xformers:
+        if torch.jit.is_tracing() or torch.jit.is_scripting():
+            disabled_xformers = True
+
+    if disabled_xformers:
+        return attention_pytorch(q, k, v, heads, mask)
+
+    if skip_reshape:
+        q, k, v = map(
+            lambda t: t.reshape(b * heads, -1, dim_head),
+            (q, k, v),
+        )
+    else:
+        q, k, v = map(
+            lambda t: t.reshape(b, -1, heads, dim_head),
+            (q, k, v),
+        )
+
+    if mask is not None:
+        pad = 8 - q.shape[1] % 8
+        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
+        mask_out[:, :, :mask.shape[-1]] = mask
+        mask = mask_out[:, :, :mask.shape[-1]]
+
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
+
+    if skip_reshape:
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, heads, -1, dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, -1, heads * dim_head)
+        )
+    else:
+        out = (
+            out.reshape(b, -1, heads * dim_head)
+        )
+
+    return out
+
+
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )
+
+    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     out = (
         out.transpose(1, 2).reshape(b, -1, heads * dim_head)
     )
     return out
 
 
-def attention_xformers_single_head_spatial(q, k, v):
+def slice_attention_single_head_spatial(q, k, v):
+    r1 = torch.zeros_like(k, device=q.device)
+    scale = (int(q.shape[-1]) ** (-0.5))
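+
+    # Shape conventions (single-head VAE attention): q arrives as (b, hw, c),
+    # k as (b, c, hw) and v as (b, c, hw); r1 therefore accumulates the
+    # (b, c, hw) output as query slices along the last axis are filled in.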
+
+    mem_free_total = memory_management.get_free_memory(q.device)
+
+    gb = 1024 ** 3
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+    modifier = 3 if q.element_size() == 2 else 2.5
+    mem_required = tensor_size * modifier
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+    while True:
+        try:
+            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = torch.bmm(q[:, i:end], k) * scale
+
+                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0, 2, 1)
+                del s1
+
+                r1[:, :, i:end] = torch.bmm(v, s2)
+                del s2
+            break
+        except memory_management.OOM_EXCEPTION as e:
+            memory_management.soft_empty_cache(True)
+            steps *= 2
+            if steps > 128:
+                raise e
+            print("out of memory error, increasing steps and trying again {}".format(steps))
+
+    return r1
+
+
+def normal_attention_single_head_spatial(q, k, v):
+    # compute attention
+    b, c, h, w = q.shape
+
+    q = q.reshape(b, c, h * w)
+    q = q.permute(0, 2, 1)  # b,hw,c
+    k = k.reshape(b, c, h * w)  # b,c,hw
+    v = v.reshape(b, c, h * w)
+
+    r1 = slice_attention_single_head_spatial(q, k, v)
+    h_ = r1.reshape(b, c, h, w)
+    del r1
+    return h_
+
+
+def xformers_attention_single_head_spatial(q, k, v):
+    # compute attention
     B, C, H, W = q.shape
     q, k, v = map(
         lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
         (q, k, v),
     )
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
-    out = out.transpose(1, 2).reshape(B, C, H, W)
+
+    try:
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+        out = out.transpose(1, 2).reshape(B, C, H, W)
+    except NotImplementedError as e:
+        out = slice_attention_single_head_spatial(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2),
+                                                  v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
     return out
 
 
-def attention_pytorch_single_head_spatial(q, k, v):
+def pytorch_attention_single_head_spatial(q, k, v):
+    # compute attention
     B, C, H, W = q.shape
     q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
         (q, k, v),
     )
-    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
-    out = out.transpose(2, 3).reshape(B, C, H, W)
+
+    try:
+        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+        out = out.transpose(2, 3).reshape(B, C, H, W)
+    except memory_management.OOM_EXCEPTION as e:
+        print("scaled_dot_product_attention OOMed: switched to slice attention")
+        out = slice_attention_single_head_spatial(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2),
+                                                  v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
     return out
 
 
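+# Cross-attention dispatch, annotated: xformers when enabled and importable,
+# else PyTorch scaled_dot_product_attention when memory_management enables it,
+# else the user-requested --attention-split, else sub-quadratic attention.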
-attention_function = attention_pytorch
-attention_function_single_head_spatial = attention_pytorch_single_head_spatial
-
-if args.xformers:
+if memory_management.xformers_enabled():
     print("Using xformers cross attention")
     attention_function = attention_xformers
-    attention_function_single_head_spatial = attention_xformers_single_head_spatial
-else:
+elif memory_management.pytorch_attention_enabled():
     print("Using pytorch cross attention")
     attention_function = attention_pytorch
-    attention_function_single_head_spatial = attention_pytorch_single_head_spatial
+elif args.attention_split:
+    print("Using split optimization for cross attention")
+    attention_function = attention_split
+else:
+    print("Using sub quadratic optimization for cross attention")
+    attention_function = attention_sub_quad
+
+if memory_management.xformers_enabled_vae():
+    print("Using xformers attention for VAE")
+    attention_function_single_head_spatial = xformers_attention_single_head_spatial
+elif memory_management.pytorch_attention_enabled():
+    print("Using pytorch attention for VAE")
+    attention_function_single_head_spatial = pytorch_attention_single_head_spatial
+else:
+    print("Using split attention for VAE")
+    attention_function_single_head_spatial = normal_attention_single_head_spatial
 
 
 class AttentionProcessorForge:
diff --git a/backend/misc/sub_quadratic_attention.py b/backend/misc/sub_quadratic_attention.py
new file mode 100644
index 00000000..9f4c23c7
--- /dev/null
+++ b/backend/misc/sub_quadratic_attention.py
@@ -0,0 +1,273 @@
+# original source:
+#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
+# license:
+#   MIT
+# credit:
+#   Amin Rezaei (original author)
+#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
+# implementation of:
+#   "Self-attention Does Not Need O(n^2) Memory":
+#   https://arxiv.org/abs/2112.05682v2
+
+from functools import partial
+import torch
+from torch import Tensor
+from torch.utils.checkpoint import checkpoint
+import math
+
+try:
+    from typing import Optional, NamedTuple, List, Protocol
+except ImportError:
+    from typing import Optional, NamedTuple, List
+    from typing_extensions import Protocol
+
+from backend import memory_management
+
+
+def dynamic_slice(
+    x: Tensor,
+    starts: List[int],
+    sizes: List[int],
+) -> Tensor:
+    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
+    return x[slicing]
+
+
+class AttnChunk(NamedTuple):
+    exp_values: Tensor
+    exp_weights_sum: Tensor
+    max_score: Tensor
+
+
+class SummarizeChunk(Protocol):
+    @staticmethod
+    def __call__(
+        query: Tensor,
+        key_t: Tensor,
+        value: Tensor,
+    ) -> AttnChunk: ...
+
+
+class ComputeQueryChunkAttn(Protocol):
+    @staticmethod
+    def __call__(
+        query: Tensor,
+        key_t: Tensor,
+        value: Tensor,
+    ) -> Tensor: ...
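+
+# The helpers below implement the streaming-softmax building block from the
+# paper: each kv chunk is reduced to (exp_values, exp_weights_sum, max_score),
+# with scores shifted by the chunk-local max before exponentiation so partial
+# sums stay finite and chunks can be renormalized exactly when merged.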
+
+def _summarize_chunk(
+    query: Tensor,
+    key_t: Tensor,
+    value: Tensor,
+    scale: float,
+    upcast_attention: bool,
+    mask,
+) -> AttnChunk:
+    if upcast_attention:
+        with torch.autocast(enabled=False, device_type='cuda'):
+            query = query.float()
+            key_t = key_t.float()
+            attn_weights = torch.baddbmm(
+                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+                query,
+                key_t,
+                alpha=scale,
+                beta=0,
+            )
+    else:
+        attn_weights = torch.baddbmm(
+            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+            query,
+            key_t,
+            alpha=scale,
+            beta=0,
+        )
+    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
+    max_score = max_score.detach()
+    attn_weights -= max_score
+    if mask is not None:
+        attn_weights += mask
+    torch.exp(attn_weights, out=attn_weights)
+    exp_weights = attn_weights.to(value.dtype)
+    exp_values = torch.bmm(exp_weights, value)
+    max_score = max_score.squeeze(-1)
+    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
+
+def _query_chunk_attention(
+    query: Tensor,
+    key_t: Tensor,
+    value: Tensor,
+    summarize_chunk: SummarizeChunk,
+    kv_chunk_size: int,
+    mask,
+) -> Tensor:
+    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
+    _, _, v_channels_per_head = value.shape
+
+    def chunk_scanner(chunk_idx: int, mask) -> AttnChunk:
+        key_chunk = dynamic_slice(
+            key_t,
+            (0, 0, chunk_idx),
+            (batch_x_heads, k_channels_per_head, kv_chunk_size)
+        )
+        value_chunk = dynamic_slice(
+            value,
+            (0, chunk_idx, 0),
+            (batch_x_heads, kv_chunk_size, v_channels_per_head)
+        )
+        if mask is not None:
+            mask = mask[:, :, chunk_idx:chunk_idx + kv_chunk_size]
+
+        return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
+
+    chunks: List[AttnChunk] = [
+        chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+    ]
+    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
+    chunk_values, chunk_weights, chunk_max = acc_chunk
+
+    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
+    max_diffs = torch.exp(chunk_max - global_max)
+    chunk_values *= torch.unsqueeze(max_diffs, -1)
+    chunk_weights *= max_diffs
+
+    all_values = chunk_values.sum(dim=0)
+    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
+    return all_values / all_weights
+
+
+# TODO: refactor CrossAttention#get_attention_scores to share code with this
+def _get_attention_scores_no_kv_chunking(
+    query: Tensor,
+    key_t: Tensor,
+    value: Tensor,
+    scale: float,
+    upcast_attention: bool,
+    mask,
+) -> Tensor:
+    if upcast_attention:
+        with torch.autocast(enabled=False, device_type='cuda'):
+            query = query.float()
+            key_t = key_t.float()
+            attn_scores = torch.baddbmm(
+                torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+                query,
+                key_t,
+                alpha=scale,
+                beta=0,
+            )
+    else:
+        attn_scores = torch.baddbmm(
+            torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+            query,
+            key_t,
+            alpha=scale,
+            beta=0,
+        )
+
+    if mask is not None:
+        attn_scores += mask
+    try:
+        attn_probs = attn_scores.softmax(dim=-1)
+        del attn_scores
+    except memory_management.OOM_EXCEPTION:
+        print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
+        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
+        torch.exp(attn_scores, out=attn_scores)
+        summed = torch.sum(attn_scores, dim=-1, keepdim=True)
+        attn_scores /= summed
+        attn_probs = attn_scores
+
+    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
+    return hidden_states_slice
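+
+# Merge identity used by _query_chunk_attention, for reference: with chunk
+# maxima m_i, global max m = max_i m_i, chunk numerators
+# V_i = sum_j exp(s_ij - m_i) * v_j and weights w_i = sum_j exp(s_ij - m_i),
+# the exact softmax-weighted output is
+#     (sum_i V_i * exp(m_i - m)) / (sum_i w_i * exp(m_i - m))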
+
+class ScannedChunk(NamedTuple):
+    chunk_idx: int
+    attn_chunk: AttnChunk
+
+
+def efficient_dot_product_attention(
+    query: Tensor,
+    key_t: Tensor,
+    value: Tensor,
+    query_chunk_size=1024,
+    kv_chunk_size: Optional[int] = None,
+    kv_chunk_size_min: Optional[int] = None,
+    use_checkpoint=True,
+    upcast_attention=False,
+    mask=None,
+):
+    """Computes efficient dot-product attention given query, transposed key, and value.
+    This is an efficient version of the attention mechanism presented in
+    https://arxiv.org/abs/2112.05682v2, which comes with O(sqrt(n)) memory requirements.
+    Args:
+        query: queries for calculating attention with shape of
+            `[batch * num_heads, tokens, channels_per_head]`.
+        key_t: keys for calculating attention with shape of
+            `[batch * num_heads, channels_per_head, tokens]`.
+        value: values to be used in attention with shape of
+            `[batch * num_heads, tokens, channels_per_head]`.
+        query_chunk_size: int: query chunks size
+        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
+        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
+        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+        mask: optional attention bias, broadcastable to
+            `[batch * num_heads, query_tokens, key_tokens]`, added to the scores before softmax.
+    Returns:
+        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
+    """
+    batch_x_heads, q_tokens, q_channels_per_head = query.shape
+    _, _, k_tokens = key_t.shape
+    scale = q_channels_per_head ** -0.5
+
+    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
+    if kv_chunk_size_min is not None:
+        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
+
+    if mask is not None and len(mask.shape) == 2:
+        mask = mask.unsqueeze(0)
+
+    def get_query_chunk(chunk_idx: int) -> Tensor:
+        return dynamic_slice(
+            query,
+            (0, chunk_idx, 0),
+            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
+        )
+
+    def get_mask_chunk(chunk_idx: int) -> Tensor:
+        if mask is None:
+            return None
+        chunk = min(query_chunk_size, q_tokens)
+        return mask[:, chunk_idx:chunk_idx + chunk]
+
+    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
+    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
+    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
+        _get_attention_scores_no_kv_chunking,
+        scale=scale,
+        upcast_attention=upcast_attention
+    ) if k_tokens <= kv_chunk_size else (
+        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
+        partial(
+            _query_chunk_attention,
+            kv_chunk_size=kv_chunk_size,
+            summarize_chunk=summarize_chunk,
+        )
+    )
+
+    if q_tokens <= query_chunk_size:
+        # fast-path for when there's just 1 query chunk
+        return compute_query_chunk_attn(
+            query=query,
+            key_t=key_t,
+            value=value,
+            mask=mask,
+        )
+
+    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
+    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
+    res = torch.cat([
+        compute_query_chunk_attn(
+            query=get_query_chunk(i * query_chunk_size),
+            key_t=key_t,
+            value=value,
+            mask=get_mask_chunk(i * query_chunk_size)
+        ) for i in range(math.ceil(q_tokens / query_chunk_size))
+    ], dim=1)
+    return res
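+
+
+# Example usage with hypothetical shapes (note the key is passed pre-transposed):
+#     q = torch.randn(8, 4096, 64)    # [batch * num_heads, q_tokens, channels_per_head]
+#     k_t = torch.randn(8, 64, 4096)  # [batch * num_heads, channels_per_head, k_tokens]
+#     v = torch.randn(8, 4096, 64)    # [batch * num_heads, k_tokens, channels_per_head]
+#     out = efficient_dot_product_attention(q, k_t, v, query_chunk_size=1024,
+#                                           use_checkpoint=False)
+#     assert out.shape == (8, 4096, 64)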
diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py
index 2a37bcb8..4fee214a 100644
--- a/modules_forge/initialization.py
+++ b/modules_forge/initialization.py
@@ -47,19 +47,19 @@ def initialize_forge():
     from modules_forge.cuda_malloc import try_cuda_malloc
     try_cuda_malloc()
 
-    import ldm_patched.modules.model_management as model_management
+    from backend import memory_management
     import torch
 
     monitor_module_moving()
 
-    device = model_management.get_torch_device()
+    device = memory_management.get_torch_device()
     torch.zeros((1, 1)).to(device, torch.float32)
-    model_management.soft_empty_cache()
+    memory_management.soft_empty_cache()
 
     import modules_forge.patch_basic
     modules_forge.patch_basic.patch_all_basics()
 
-    from modules_forge import stream
+    from backend import stream
     print('CUDA Stream Activated: ', stream.using_stream)
 
     from modules_forge.shared import diffusers_dir
diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py
index 5b61819a..bb530325 100644
--- a/modules_forge/patch_basic.py
+++ b/modules_forge/patch_basic.py
@@ -1,6 +1,6 @@
 import torch
 import os
-import safetensors
+import safetensors.torch
 
 
 def build_loaded(module, loader_name):