From 3079b81547865acd18751fe79a417a35f464fca8 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Wed, 31 Jul 2024 21:35:19 -0700
Subject: [PATCH] make attn func name more technically correct

---
 backend/attention.py         | 10 +++++-----
 backend/nn/autoencoder_kl.py |  4 ++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/backend/attention.py b/backend/attention.py
index 5e891104..776da59f 100644
--- a/backend/attention.py
+++ b/backend/attention.py
@@ -54,7 +54,7 @@ def attention_pytorch(q, k, v, heads, mask=None):
     return out
 
 
-def attention_xformers_single_head(q, k, v):
+def attention_xformers_single_head_spatial(q, k, v):
     B, C, H, W = q.shape
     q, k, v = map(
         lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
@@ -65,7 +65,7 @@
     return out
 
 
-def attention_pytorch_single_head(q, k, v):
+def attention_pytorch_single_head_spatial(q, k, v):
     B, C, H, W = q.shape
     q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
@@ -77,16 +77,16 @@
 
 
 attention_function = attention_pytorch
-attention_function_single_head = attention_pytorch_single_head
+attention_function_single_head_spatial = attention_pytorch_single_head_spatial
 
 if args.xformers:
     print("Using xformers cross attention")
     attention_function = attention_xformers
-    attention_function_single_head = attention_xformers_single_head
+    attention_function_single_head_spatial = attention_xformers_single_head_spatial
 else:
     print("Using pytorch cross attention")
     attention_function = attention_pytorch
-    attention_function_single_head = attention_pytorch_single_head
+    attention_function_single_head_spatial = attention_pytorch_single_head_spatial
 
 
 class AttentionProcessorForge:
diff --git a/backend/nn/autoencoder_kl.py b/backend/nn/autoencoder_kl.py
index 5e830fc8..5b081f8d 100644
--- a/backend/nn/autoencoder_kl.py
+++ b/backend/nn/autoencoder_kl.py
@@ -1,7 +1,7 @@
 import torch
 import numpy as np
 
-from backend.attention import attention_function_single_head
+from backend.attention import attention_function_single_head_spatial
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from typing import Optional, Tuple
 from torch import nn
@@ -181,7 +181,7 @@ class AttnBlock(nn.Module):
         q = self.q(h_)
         k = self.k(h_)
         v = self.v(h_)
-        h_ = attention_function_single_head(q, k, v)
+        h_ = attention_function_single_head_spatial(q, k, v)
         h_ = self.proj_out(h_)
 
         return x + h_
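
For context on why "spatial" is the more technically correct name: the renamed helpers take a 4-D feature map of shape (B, C, H, W), such as the VAE AttnBlock activations touched above, flatten the H*W grid into a token axis, attend with a single head, and fold the result back into (B, C, H, W). The sketch below is illustrative only, not the repository code: the initial reshape mirrors the lines visible in the patch, while the scaled_dot_product_attention call and the reshape back are assumptions about the unchanged parts of the function body.

# Minimal sketch of single-head spatial attention, assuming the pytorch
# variant's (B, 1, C, H*W) layout shown in the hunk above. Only the first
# reshape is taken from the patch; the rest is an illustrative assumption.
import torch
import torch.nn.functional as F


def single_head_spatial_attention_sketch(q, k, v):
    B, C, H, W = q.shape
    # Flatten the H*W spatial grid into a token axis: (B, 1, H*W, C).
    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )
    # Assumed: standard scaled dot-product attention with a single head.
    out = F.scaled_dot_product_attention(q, k, v)
    # Assumed: restore the (B, C, H, W) layout expected by the caller.
    return out.transpose(2, 3).reshape(B, C, H, W)


if __name__ == "__main__":
    x = torch.randn(2, 64, 16, 16)
    print(single_head_spatial_attention_sketch(x, x, x).shape)  # torch.Size([2, 64, 16, 16])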