mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git

Update sub_quadratic_attention.py
@@ -16,55 +16,60 @@ from torch.utils.checkpoint import checkpoint
 import math

 try:
     from typing import Optional, NamedTuple, List, Protocol
 except ImportError:
     from typing import Optional, NamedTuple, List
     from typing_extensions import Protocol

 from torch import Tensor
 from typing import List

-from ldm_patched.modules import model_management
+from backend import memory_management

+
 def dynamic_slice(
     x: Tensor,
     starts: List[int],
     sizes: List[int],
 ) -> Tensor:
     slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
     return x[slicing]

+
 class AttnChunk(NamedTuple):
     exp_values: Tensor
     exp_weights_sum: Tensor
     max_score: Tensor

+
 class SummarizeChunk(Protocol):
     @staticmethod
     def __call__(
         query: Tensor,
         key_t: Tensor,
         value: Tensor,
     ) -> AttnChunk: ...

+
 class ComputeQueryChunkAttn(Protocol):
     @staticmethod
     def __call__(
         query: Tensor,
         key_t: Tensor,
         value: Tensor,
     ) -> Tensor: ...

+
 def _summarize_chunk(
     query: Tensor,
     key_t: Tensor,
     value: Tensor,
     scale: float,
     upcast_attention: bool,
     mask,
 ) -> AttnChunk:
     if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'cuda'):
+        with torch.autocast(enabled=False, device_type='cuda'):
             query = query.float()
             key_t = key_t.float()
             attn_weights = torch.baddbmm(
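For context on the helper that opens this hunk: `dynamic_slice` just turns parallel start/size lists into a per-dimension slice index, i.e. a narrow view with no copy. A minimal sketch of how it behaves (the shapes are illustrative, not from the commit):

    import torch
    from torch import Tensor
    from typing import List

    def dynamic_slice(x: Tensor, starts: List[int], sizes: List[int]) -> Tensor:
        # Same body as in the file above: one slice per dimension.
        slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
        return x[slicing]

    # Illustrative layout: (batch*heads, channels, k_tokens), as used for the transposed key.
    key_t = torch.randn(2, 64, 128)

    # Grab a 32-token key chunk starting at token 96, keeping all batch/channel entries.
    chunk = dynamic_slice(key_t, starts=[0, 0, 96], sizes=[2, 64, 32])
    assert chunk.shape == (2, 64, 32)
    assert torch.equal(chunk, key_t[:, :, 96:128])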
@@ -93,13 +98,14 @@ def _summarize_chunk
     max_score = max_score.squeeze(-1)
     return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)

+
 def _query_chunk_attention(
     query: Tensor,
     key_t: Tensor,
     value: Tensor,
     summarize_chunk: SummarizeChunk,
     kv_chunk_size: int,
     mask,
 ) -> Tensor:
     batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
     _, _, v_channels_per_head = value.shape
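The `AttnChunk` built at the end of `_summarize_chunk` above and consumed by `_query_chunk_attention` carries a partial softmax in stabilized form: max-shifted exponentiated scores multiplied by the values, their row sums, and the per-row score maximum. A minimal single-chunk sketch (shapes illustrative) showing those three pieces reproduce ordinary softmax attention:

    import torch

    q = torch.randn(1, 8, 16)      # (batch*heads, q_tokens, channels)
    k_t = torch.randn(1, 16, 8)    # transposed key: (batch*heads, channels, k_tokens)
    v = torch.randn(1, 8, 16)
    scale = 16 ** -0.5

    scores = torch.bmm(q, k_t) * scale
    max_score = scores.max(dim=-1, keepdim=True).values
    exp_weights = torch.exp(scores - max_score)    # shifted so exp() cannot overflow
    exp_values = torch.bmm(exp_weights, v)         # numerator of softmax attention
    exp_weights_sum = exp_weights.sum(dim=-1)      # denominator

    out = exp_values / exp_weights_sum.unsqueeze(-1)
    ref = torch.bmm(scores.softmax(dim=-1), v)
    assert torch.allclose(out, ref, atol=1e-5)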
@@ -116,7 +122,7 @@ def _query_chunk_attention
             (batch_x_heads, kv_chunk_size, v_channels_per_head)
         )
         if mask is not None:
-            mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size]
+            mask = mask[:, :, chunk_idx:chunk_idx + kv_chunk_size]

         return summarize_chunk(query, key_chunk, value_chunk, mask=mask)

@@ -135,17 +141,18 @@ def _query_chunk_attention
     all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
     return all_values / all_weights

+
 # TODO: refactor CrossAttention#get_attention_scores to share code with this
 def _get_attention_scores_no_kv_chunking(
     query: Tensor,
     key_t: Tensor,
     value: Tensor,
     scale: float,
     upcast_attention: bool,
     mask,
 ) -> Tensor:
     if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'cuda'):
+        with torch.autocast(enabled=False, device_type='cuda'):
             query = query.float()
             key_t = key_t.float()
             attn_scores = torch.baddbmm(
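The `all_values / all_weights` at the top of this hunk is the final step of merging per-chunk summaries: each chunk's numerator and denominator are rescaled from its local maximum to the global maximum (a log-sum-exp merge) before being summed and divided. A standalone two-chunk sketch of that merge; the variable names echo the file, but this is a simplified reconstruction, not a verbatim excerpt:

    import torch

    def summarize(q, k_t, v, scale):
        # One chunk's stabilized partial softmax: numerator, denominator, local max.
        s = torch.bmm(q, k_t) * scale
        m = s.max(dim=-1, keepdim=True).values
        w = torch.exp(s - m)
        return torch.bmm(w, v), w.sum(dim=-1), m.squeeze(-1)

    q = torch.randn(1, 4, 16)
    k_t = torch.randn(1, 16, 8)
    v = torch.randn(1, 8, 16)
    scale = 16 ** -0.5

    # Split key/value into two token chunks and summarize each independently.
    chunks = [summarize(q, k_t[:, :, i:i + 4], v[:, i:i + 4, :], scale) for i in (0, 4)]
    chunk_values, chunk_weights, chunk_max = map(torch.stack, zip(*chunks))

    # Rescale every chunk from its local max to the global max, then combine.
    global_max = chunk_max.max(dim=0, keepdim=True).values
    max_diffs = torch.exp(chunk_max - global_max)
    all_values = (chunk_values * max_diffs.unsqueeze(-1)).sum(dim=0)
    all_weights = (chunk_weights * max_diffs).sum(dim=0)

    out = all_values / all_weights.unsqueeze(-1)
    ref = torch.bmm((torch.bmm(q, k_t) * scale).softmax(dim=-1), v)
    assert torch.allclose(out, ref, atol=1e-5)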
@@ -169,7 +176,7 @@ def _get_attention_scores_no_kv_chunking
     try:
         attn_probs = attn_scores.softmax(dim=-1)
         del attn_scores
-    except model_management.OOM_EXCEPTION:
+    except memory_management.OOM_EXCEPTION:
         print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
         attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
         torch.exp(attn_scores, out=attn_scores)
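The `except` branch shown here falls back to a numerically stable softmax computed in place, so no second score-sized buffer has to be allocated: subtract the row max, exponentiate into the same storage, then (in the lines following this hunk) divide by the row sums. A minimal sketch of why that sequence equals `softmax(dim=-1)`:

    import torch

    attn_scores = torch.randn(2, 512, 512)
    ref = attn_scores.softmax(dim=-1)

    # In-place, numerically stable softmax over the last dimension.
    attn_scores -= attn_scores.max(dim=-1, keepdim=True).values  # shift so the max is 0
    torch.exp(attn_scores, out=attn_scores)                      # exponentiate in place
    attn_scores /= attn_scores.sum(dim=-1, keepdim=True)         # normalize each row

    assert torch.allclose(attn_scores, ref, atol=1e-6)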
@@ -180,20 +187,22 @@ def _get_attention_scores_no_kv_chunking
     hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
     return hidden_states_slice

+
 class ScannedChunk(NamedTuple):
     chunk_idx: int
     attn_chunk: AttnChunk

+
 def efficient_dot_product_attention(
     query: Tensor,
     key_t: Tensor,
     value: Tensor,
     query_chunk_size=1024,
     kv_chunk_size: Optional[int] = None,
     kv_chunk_size_min: Optional[int] = None,
     use_checkpoint=True,
     upcast_attention=False,
-    mask = None,
+    mask=None,
 ):
     """Computes efficient dot-product attention given query, transposed key, and value.
     This is efficient version of attention presented in
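One easy-to-miss detail of the signature shown here: the key is expected pre-transposed (`key_t`, laid out `(batch*heads, channels, k_tokens)`), unlike most attention helpers. A hedged usage sketch; the import path is inferred from the filename in this commit and the defaults are summarized from the upstream docstring, so treat both as assumptions:

    import torch

    # Assumed path, based on the filename in this commit:
    from ldm_patched.modules.sub_quadratic_attention import efficient_dot_product_attention

    # Heads folded into the batch dimension: (batch*heads, tokens, channels).
    q = torch.randn(8, 2048, 40)
    k = torch.randn(8, 2048, 40)
    v = torch.randn(8, 2048, 40)

    out = efficient_dot_product_attention(
        query=q,
        key_t=k.transpose(1, 2),   # (batch*heads, channels, k_tokens)
        value=v,
        query_chunk_size=1024,     # queries processed 1024 tokens at a time
        kv_chunk_size=None,        # None ~ sqrt(k_tokens), per the docstring
        use_checkpoint=True,       # recompute chunk summaries during backward
        upcast_attention=False,
        mask=None,
    )
    assert out.shape == q.shape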
@@ -234,7 +243,7 @@ def efficient_dot_product_attention
         if mask is None:
             return None
         chunk = min(query_chunk_size, q_tokens)
-        return mask[:,chunk_idx:chunk_idx + chunk]
+        return mask[:, chunk_idx:chunk_idx + chunk]

     summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
     summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
@@ -259,7 +268,7 @@ def efficient_dot_product_attention
             value=value,
             mask=mask,
         )

     # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
     # and pass slices to be mutated, instead of torch.cat()ing the returned slices
     res = torch.cat([
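The `torch.cat([` this hunk ends on opens the outer query-chunk loop: each query slice attends over all keys and values, and the per-chunk outputs are concatenated along the token axis (the TODO above it suggests writing into a preallocated buffer instead). A standalone sketch of that outer loop, with a stand-in for `compute_query_chunk_attn`:

    import torch

    def attend(q, k_t, v):
        # Stand-in for compute_query_chunk_attn: plain attention on one query chunk.
        scale = q.shape[-1] ** -0.5
        return torch.bmm((torch.bmm(q, k_t) * scale).softmax(dim=-1), v)

    q = torch.randn(2, 1024, 40)
    k_t = torch.randn(2, 40, 1024)
    v = torch.randn(2, 1024, 40)
    query_chunk_size = 256

    # Chunking the queries changes memory use, not the result: every query row
    # still sees the full key/value sequence.
    res = torch.cat([
        attend(q[:, i:i + query_chunk_size, :], k_t, v)
        for i in range(0, q.shape[1], query_chunk_size)
    ], dim=1)

    assert res.shape == q.shape
    assert torch.allclose(res, attend(q, k_t, v), atol=1e-5)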