Update sub_quadratic_attention.py

This commit is contained in:
layerdiffusion
2024-08-03 15:19:45 -07:00
parent 4add428e25
commit c7b1789892

View File

@@ -24,7 +24,8 @@ except ImportError:
from torch import Tensor
from typing import List
from ldm_patched.modules import model_management
from backend import memory_management
def dynamic_slice(
x: Tensor,
@@ -34,11 +35,13 @@ def dynamic_slice(
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
return x[slicing]
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
class SummarizeChunk(Protocol):
@staticmethod
def __call__(
@@ -47,6 +50,7 @@ class SummarizeChunk(Protocol):
value: Tensor,
) -> AttnChunk: ...
class ComputeQueryChunkAttn(Protocol):
@staticmethod
def __call__(
@@ -55,6 +59,7 @@ class ComputeQueryChunkAttn(Protocol):
value: Tensor,
) -> Tensor: ...
def _summarize_chunk(
query: Tensor,
key_t: Tensor,
@@ -93,6 +98,7 @@ def _summarize_chunk(
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
def _query_chunk_attention(
query: Tensor,
key_t: Tensor,
@@ -135,6 +141,7 @@ def _query_chunk_attention(
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
@@ -169,7 +176,7 @@ def _get_attention_scores_no_kv_chunking(
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except model_management.OOM_EXCEPTION:
except memory_management.OOM_EXCEPTION:
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
torch.exp(attn_scores, out=attn_scores)
@@ -180,10 +187,12 @@ def _get_attention_scores_no_kv_chunking(
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
return hidden_states_slice
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
def efficient_dot_product_attention(
query: Tensor,
key_t: Tensor,