make attn func name more technically correct

layerdiffusion
2024-07-31 21:35:19 -07:00
parent 3fb9d5a85c
commit 3079b81547
2 changed files with 7 additions and 7 deletions


@@ -54,7 +54,7 @@ def attention_pytorch(q, k, v, heads, mask=None):
     return out


-def attention_xformers_single_head(q, k, v):
+def attention_xformers_single_head_spatial(q, k, v):
     B, C, H, W = q.shape
     q, k, v = map(
         lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
@@ -65,7 +65,7 @@ def attention_xformers_single_head(q, k, v):
     return out


-def attention_pytorch_single_head(q, k, v):
+def attention_pytorch_single_head_spatial(q, k, v):
     B, C, H, W = q.shape
     q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
@@ -77,16 +77,16 @@ def attention_pytorch_single_head(q, k, v):
 attention_function = attention_pytorch
-attention_function_single_head = attention_pytorch_single_head
+attention_function_single_head_spatial = attention_pytorch_single_head_spatial

 if args.xformers:
     print("Using xformers cross attention")
     attention_function = attention_xformers
-    attention_function_single_head = attention_xformers_single_head
+    attention_function_single_head_spatial = attention_xformers_single_head_spatial
 else:
     print("Using pytorch cross attention")
     attention_function = attention_pytorch
-    attention_function_single_head = attention_pytorch_single_head
+    attention_function_single_head_spatial = attention_pytorch_single_head_spatial


 class AttentionProcessorForge:

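The rename clarifies the contract: unlike attention_pytorch(q, k, v, heads, ...), which expects tokens already flattened to (batch, sequence, channels), these single-head helpers take a 4-D spatial feature map of shape (B, C, H, W) and flatten the H*W grid into the attention sequence themselves. A minimal sketch of the full PyTorch variant, assuming the lines the hunk cuts off follow the usual flatten, attend, restore pattern (the scaled_dot_product_attention call and the final reshape are inferred, not shown in this commit):

import torch

def attention_pytorch_single_head_spatial(q, k, v):
    # q, k, v: spatial feature maps of shape (B, C, H, W)
    B, C, H, W = q.shape
    # Flatten the H*W grid into a token sequence: (B, 1, H*W, C)
    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )
    # Single-head scaled dot-product attention over the H*W tokens
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # Restore the spatial layout (B, C, H, W); inferred, not part of the diff
    return out.transpose(2, 3).reshape(B, C, H, W)

The xformers variant above does the same thing, except it flattens to (B, H*W, C), presumably for xformers' memory-efficient attention.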

@@ -1,7 +1,7 @@
 import torch
 import numpy as np
-from backend.attention import attention_function_single_head
+from backend.attention import attention_function_single_head_spatial
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from typing import Optional, Tuple
 from torch import nn
@@ -181,7 +181,7 @@ class AttnBlock(nn.Module):
         q = self.q(h_)
         k = self.k(h_)
         v = self.v(h_)
-        h_ = attention_function_single_head(q, k, v)
+        h_ = attention_function_single_head_spatial(q, k, v)
         h_ = self.proj_out(h_)

         return x + h_
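At the AttnBlock call site, q, k and v are projections of the decoder feature map h_, so they arrive as (B, C, H, W) tensors; the _spatial suffix documents that at a glance. A hypothetical shape check against the sketch above (all shapes invented for illustration):

import torch

# Invented shapes: batch 1, 512 channels, 64x64 spatial grid
q = k = v = torch.randn(1, 512, 64, 64)
out = attention_pytorch_single_head_spatial(q, k, v)
assert out.shape == (1, 512, 64, 64)  # spatial layout is preserved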