diff --git a/backend/misc/sub_quadratic_attention.py b/backend/misc/sub_quadratic_attention.py
index 9f4c23c7..687b6a3a 100644
--- a/backend/misc/sub_quadratic_attention.py
+++ b/backend/misc/sub_quadratic_attention.py
@@ -16,55 +16,60 @@ from torch.utils.checkpoint import checkpoint
 import math
 
 try:
-  from typing import Optional, NamedTuple, List, Protocol
+    from typing import Optional, NamedTuple, List, Protocol
 except ImportError:
-  from typing import Optional, NamedTuple, List
-  from typing_extensions import Protocol
+    from typing import Optional, NamedTuple, List
+    from typing_extensions import Protocol
 
 from torch import Tensor
 from typing import List
 
-from ldm_patched.modules import model_management
+from backend import memory_management
+
 
 def dynamic_slice(
-    x: Tensor,
-    starts: List[int],
-    sizes: List[int],
+        x: Tensor,
+        starts: List[int],
+        sizes: List[int],
 ) -> Tensor:
     slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
     return x[slicing]
 
+
 class AttnChunk(NamedTuple):
     exp_values: Tensor
     exp_weights_sum: Tensor
     max_score: Tensor
 
+
 class SummarizeChunk(Protocol):
     @staticmethod
     def __call__(
-        query: Tensor,
-        key_t: Tensor,
-        value: Tensor,
+            query: Tensor,
+            key_t: Tensor,
+            value: Tensor,
     ) -> AttnChunk: ...
 
+
 class ComputeQueryChunkAttn(Protocol):
     @staticmethod
     def __call__(
+            query: Tensor,
+            key_t: Tensor,
+            value: Tensor,
+    ) -> Tensor: ...
+
+
+def _summarize_chunk(
         query: Tensor,
         key_t: Tensor,
         value: Tensor,
-    ) -> Tensor: ...
-
-def _summarize_chunk(
-    query: Tensor,
-    key_t: Tensor,
-    value: Tensor,
-    scale: float,
-    upcast_attention: bool,
-    mask,
+        scale: float,
+        upcast_attention: bool,
+        mask,
 ) -> AttnChunk:
     if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'cuda'):
+        with torch.autocast(enabled=False, device_type='cuda'):
             query = query.float()
             key_t = key_t.float()
             attn_weights = torch.baddbmm(
@@ -93,13 +98,14 @@ def _summarize_chunk(
     max_score = max_score.squeeze(-1)
     return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
 
+
 def _query_chunk_attention(
-    query: Tensor,
-    key_t: Tensor,
-    value: Tensor,
-    summarize_chunk: SummarizeChunk,
-    kv_chunk_size: int,
-    mask,
+        query: Tensor,
+        key_t: Tensor,
+        value: Tensor,
+        summarize_chunk: SummarizeChunk,
+        kv_chunk_size: int,
+        mask,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape
@@ -116,7 +122,7 @@
             (batch_x_heads, kv_chunk_size, v_channels_per_head)
         )
         if mask is not None:
-            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
+            mask = mask[:, :, chunk_idx:chunk_idx + kv_chunk_size]
 
         return summarize_chunk(query, key_chunk, value_chunk, mask=mask)
@@ -135,17 +141,18 @@
     all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
     return all_values / all_weights
 
+
 # TODO: refactor CrossAttention#get_attention_scores to share code with this
 def _get_attention_scores_no_kv_chunking(
-    query: Tensor,
-    key_t: Tensor,
-    value: Tensor,
-    scale: float,
-    upcast_attention: bool,
-    mask,
+        query: Tensor,
+        key_t: Tensor,
+        value: Tensor,
+        scale: float,
+        upcast_attention: bool,
+        mask,
 ) -> Tensor:
     if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'cuda'):
+        with torch.autocast(enabled=False, device_type='cuda'):
             query = query.float()
             key_t = key_t.float()
             attn_scores = torch.baddbmm(
@@ -169,7 +176,7 @@
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except model_management.OOM_EXCEPTION:
+    except memory_management.OOM_EXCEPTION:
         print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
         torch.exp(attn_scores, out=attn_scores)
@@ -180,20 +187,22 @@
     hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
     return hidden_states_slice
 
+
 class ScannedChunk(NamedTuple):
     chunk_idx: int
     attn_chunk: AttnChunk
 
+
 def efficient_dot_product_attention(
-    query: Tensor,
-    key_t: Tensor,
-    value: Tensor,
-    query_chunk_size=1024,
-    kv_chunk_size: Optional[int] = None,
-    kv_chunk_size_min: Optional[int] = None,
-    use_checkpoint=True,
-    upcast_attention=False,
-    mask = None,
+        query: Tensor,
+        key_t: Tensor,
+        value: Tensor,
+        query_chunk_size=1024,
+        kv_chunk_size: Optional[int] = None,
+        kv_chunk_size_min: Optional[int] = None,
+        use_checkpoint=True,
+        upcast_attention=False,
+        mask=None,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
@@ -234,7 +243,7 @@
         if mask is None:
             return None
         chunk = min(query_chunk_size, q_tokens)
-        return mask[:,chunk_idx:chunk_idx + chunk]
+        return mask[:, chunk_idx:chunk_idx + chunk]
 
     summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
     summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
@@ -259,7 +268,7 @@
         value=value,
         mask=mask,
     )
-    
+
     # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
     # and pass slices to be mutated, instead of torch.cat()ing the returned slices
     res = torch.cat([
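Reviewer note: since this patch only renames the memory-management import and normalizes formatting, the public entry point is unchanged. Below is a minimal sketch of how the module is driven after the rename, following the call convention in the docstring above (the key is passed pre-transposed). The tensor sizes, the `use_checkpoint=False` choice, and running from the repo root are illustrative assumptions, not part of this change.

```python
# Minimal driver sketch for the renamed module (assumes the repo root is on
# sys.path so that `backend` is importable). Sizes are illustrative only.
import torch
from backend.misc.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, q_tokens, kv_tokens, head_dim = 16, 4096, 4096, 64
query = torch.randn(batch_x_heads, q_tokens, head_dim)
value = torch.randn(batch_x_heads, kv_tokens, head_dim)
# Per the docstring, the key is supplied already transposed:
# (batch * heads, channels_per_head, kv_tokens)
key_t = torch.randn(batch_x_heads, kv_tokens, head_dim).transpose(1, 2)

out = efficient_dot_product_attention(
    query, key_t, value,
    query_chunk_size=1024,
    kv_chunk_size=None,     # None lets the function pick its own chunk size
    use_checkpoint=False,   # gradient checkpointing only pays off in training
)
assert out.shape == (batch_x_heads, q_tokens, head_dim)
```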