diff --git a/backend/nn/flux.py b/backend/nn/flux.py index 035c46a3..0b80ec3a 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -108,9 +108,9 @@ class QKNorm(nn.Module): self.key_norm = RMSNorm(dim) def forward(self, q, k, v): + del v q = self.query_norm(q) k = self.key_norm(k) - del v return q.to(k), k.to(q) @@ -128,8 +128,8 @@ class SelfAttention(nn.Module): # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) B, L, _ = qkv.shape - qkv = qkv.view(B, L, 3, self.num_heads, -1) # Split into Q, K, V - q, k, v = qkv.permute(2, 0, 3, 1, 4) # Rearrange to (K B H L D) + qkv = qkv.view(B, L, 3, self.num_heads, -1) + q, k, v = qkv.permute(2, 0, 3, 1, 4) del qkv q, k = self.norm(q, k, v)