mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-05-01 11:41:23 +00:00
Update sub_quadratic_attention.py
This commit is contained in:
@@ -24,7 +24,8 @@ except ImportError:
|
||||
from torch import Tensor
|
||||
from typing import List
|
||||
|
||||
from ldm_patched.modules import model_management
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
def dynamic_slice(
|
||||
x: Tensor,
|
||||
@@ -34,11 +35,13 @@ def dynamic_slice(
|
||||
slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
|
||||
return x[slicing]
|
||||
|
||||
|
||||
class AttnChunk(NamedTuple):
|
||||
exp_values: Tensor
|
||||
exp_weights_sum: Tensor
|
||||
max_score: Tensor
|
||||
|
||||
|
||||
class SummarizeChunk(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
@@ -47,6 +50,7 @@ class SummarizeChunk(Protocol):
|
||||
value: Tensor,
|
||||
) -> AttnChunk: ...
|
||||
|
||||
|
||||
class ComputeQueryChunkAttn(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
@@ -55,6 +59,7 @@ class ComputeQueryChunkAttn(Protocol):
|
||||
value: Tensor,
|
||||
) -> Tensor: ...
|
||||
|
||||
|
||||
def _summarize_chunk(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
@@ -93,6 +98,7 @@ def _summarize_chunk(
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||
|
||||
|
||||
def _query_chunk_attention(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
@@ -135,6 +141,7 @@ def _query_chunk_attention(
|
||||
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
||||
return all_values / all_weights
|
||||
|
||||
|
||||
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
||||
def _get_attention_scores_no_kv_chunking(
|
||||
query: Tensor,
|
||||
@@ -169,7 +176,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
try:
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
except model_management.OOM_EXCEPTION:
|
||||
except memory_management.OOM_EXCEPTION:
|
||||
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
|
||||
torch.exp(attn_scores, out=attn_scores)
|
||||
@@ -180,10 +187,12 @@ def _get_attention_scores_no_kv_chunking(
|
||||
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
|
||||
return hidden_states_slice
|
||||
|
||||
|
||||
class ScannedChunk(NamedTuple):
|
||||
chunk_idx: int
|
||||
attn_chunk: AttnChunk
|
||||
|
||||
|
||||
def efficient_dot_product_attention(
|
||||
query: Tensor,
|
||||
key_t: Tensor,
|
||||
|
||||
Reference in New Issue
Block a user