mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 10:41:25 +00:00
format
This commit is contained in:
@@ -108,9 +108,9 @@ class QKNorm(nn.Module):
|
|||||||
self.key_norm = RMSNorm(dim)
|
self.key_norm = RMSNorm(dim)
|
||||||
|
|
||||||
def forward(self, q, k, v):
|
def forward(self, q, k, v):
|
||||||
|
del v
|
||||||
q = self.query_norm(q)
|
q = self.query_norm(q)
|
||||||
k = self.key_norm(k)
|
k = self.key_norm(k)
|
||||||
del v
|
|
||||||
return q.to(k), k.to(q)
|
return q.to(k), k.to(q)
|
||||||
|
|
||||||
|
|
||||||
@@ -128,8 +128,8 @@ class SelfAttention(nn.Module):
|
|||||||
|
|
||||||
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
B, L, _ = qkv.shape
|
B, L, _ = qkv.shape
|
||||||
qkv = qkv.view(B, L, 3, self.num_heads, -1) # Split into Q, K, V
|
qkv = qkv.view(B, L, 3, self.num_heads, -1)
|
||||||
q, k, v = qkv.permute(2, 0, 3, 1, 4) # Rearrange to (K B H L D)
|
q, k, v = qkv.permute(2, 0, 3, 1, 4)
|
||||||
del qkv
|
del qkv
|
||||||
|
|
||||||
q, k = self.norm(q, k, v)
|
q, k = self.norm(q, k, v)
|
||||||
|
|||||||
Reference in New Issue
Block a user