This commit is contained in:
layerdiffusion
2024-08-08 16:50:55 -07:00
parent 10e4a3779d
commit a1c2764e4a

View File

@@ -108,9 +108,9 @@ class QKNorm(nn.Module):
self.key_norm = RMSNorm(dim)
def forward(self, q, k, v):
del v
q = self.query_norm(q)
k = self.key_norm(k)
del v
return q.to(k), k.to(q)
@@ -128,8 +128,8 @@ class SelfAttention(nn.Module):
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
B, L, _ = qkv.shape
qkv = qkv.view(B, L, 3, self.num_heads, -1) # Split into Q, K, V
q, k, v = qkv.permute(2, 0, 3, 1, 4) # Rearrange to (K B H L D)
qkv = qkv.view(B, L, 3, self.num_heads, -1)
q, k, v = qkv.permute(2, 0, 3, 1, 4)
del qkv
q, k = self.norm(q, k, v)