Mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
make attn func name more technically correct
@@ -54,7 +54,7 @@ def attention_pytorch(q, k, v, heads, mask=None):
     return out
 
 
-def attention_xformers_single_head(q, k, v):
+def attention_xformers_single_head_spatial(q, k, v):
     B, C, H, W = q.shape
     q, k, v = map(
         lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
@@ -65,7 +65,7 @@ def attention_xformers_single_head(q, k, v):
     return out
 
 
-def attention_pytorch_single_head(q, k, v):
+def attention_pytorch_single_head_spatial(q, k, v):
     B, C, H, W = q.shape
     q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
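Both renamed helpers take spatial feature maps of shape (B, C, H, W) rather than token sequences, which is what the new _spatial suffix refers to: the xformers variant flattens to (B, H*W, C), while the pytorch variant keeps an explicit head dimension of 1 and works on (B, 1, H*W, C). The diff only shows the first lines of each function, so the following is a minimal sketch of the pytorch-style variant, assuming torch.nn.functional.scaled_dot_product_attention and a reshape back to (B, C, H, W) at the end; it is an illustration, not the repository's exact code.

import torch
import torch.nn.functional as F

def attention_single_head_spatial_sketch(q, k, v):
    # Illustrative only: q, k, v are spatial feature maps of shape (B, C, H, W).
    B, C, H, W = q.shape
    # Flatten the H*W positions into a token axis with a single head,
    # mirroring view(B, 1, C, -1).transpose(2, 3) above: (B, 1, H*W, C).
    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )
    # Single-head scaled dot-product attention over the H*W positions.
    out = F.scaled_dot_product_attention(q, k, v)
    # Restore the spatial layout: (B, 1, H*W, C) -> (B, C, H, W).
    return out.transpose(2, 3).reshape(B, C, H, W)

# Example: a 16x16 feature map with 64 channels keeps its shape.
x = torch.randn(2, 64, 16, 16)
assert attention_single_head_spatial_sketch(x, x, x).shape == (2, 64, 16, 16)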
@@ -77,16 +77,16 @@ def attention_pytorch_single_head(q, k, v):
 
 
 attention_function = attention_pytorch
-attention_function_single_head = attention_pytorch_single_head
+attention_function_single_head_spatial = attention_pytorch_single_head_spatial
 
 if args.xformers:
     print("Using xformers cross attention")
     attention_function = attention_xformers
-    attention_function_single_head = attention_xformers_single_head
+    attention_function_single_head_spatial = attention_xformers_single_head_spatial
 else:
     print("Using pytorch cross attention")
     attention_function = attention_pytorch
-    attention_function_single_head = attention_pytorch_single_head
+    attention_function_single_head_spatial = attention_pytorch_single_head_spatial
 
 
 class AttentionProcessorForge:
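The rename also changes the module-level dispatch names, so any code that imported attention_function_single_head must switch to attention_function_single_head_spatial, as the second file below does. A hypothetical caller after this commit might look like the sketch here; the module path and function name come from the diff, while the feature-map shape is only an example.

import torch
from backend.attention import attention_function_single_head_spatial

feat = torch.randn(1, 512, 32, 32)  # (B, C, H, W) feature map, e.g. from a decoder block
# Self-attention over the 32*32 spatial positions; the dispatch above picks the
# xformers or pytorch implementation depending on args.xformers at import time.
out = attention_function_single_head_spatial(feat, feat, feat)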
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 
-from backend.attention import attention_function_single_head
+from backend.attention import attention_function_single_head_spatial
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from typing import Optional, Tuple
 from torch import nn
@@ -181,7 +181,7 @@ class AttnBlock(nn.Module):
         q = self.q(h_)
         k = self.k(h_)
         v = self.v(h_)
-        h_ = attention_function_single_head(q, k, v)
+        h_ = attention_function_single_head_spatial(q, k, v)
         h_ = self.proj_out(h_)
         return x + h_
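For context, the patched AttnBlock is a VAE-style spatial self-attention block. The sketch below shows how the renamed helper slots into such a block's forward pass: the q/k/v/proj_out layer names and the residual return come from the diff, while the constructor, GroupNorm, and 1x1 convolutions are assumptions about the surrounding code rather than the repository's actual class.

from torch import nn
from backend.attention import attention_function_single_head_spatial

class AttnBlockSketch(nn.Module):
    # Illustrative stand-in for the patched AttnBlock, not the real class.
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)       # assumed normalization layer
        self.q = nn.Conv2d(channels, channels, 1)    # 1x1 projections keep (B, C, H, W)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        self.proj_out = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        h_ = self.norm(x)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        # The _spatial helper flattens H*W internally and returns (B, C, H, W),
        # so the call site needs no reshaping before the residual add.
        h_ = attention_function_single_head_spatial(q, k, v)
        h_ = self.proj_out(h_)
        return x + h_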