diff --git a/backend/nn/flux.py b/backend/nn/flux.py index 1a54992c..ae5144f6 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -129,8 +129,7 @@ class SelfAttention(nn.Module): # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) B, L, _ = qkv.shape qkv = qkv.view(B, L, 3, self.num_heads, -1) # Split into Q, K, V - qkv = qkv.permute(2, 0, 3, 1, 4) # Rearrange to (K B H L D) - q, k, v = qkv[0], qkv[1], qkv[2] # Separate Q, K, V + q, k, v = qkv.permute(2, 0, 3, 1, 4) # Rearrange to (K B H L D) del qkv q, k = self.norm(q, k, v) @@ -196,10 +195,7 @@ class DoubleStreamBlock(nn.Module): B, L, _ = img_qkv.shape H = self.num_heads D = img_qkv.shape[-1] // (3 * H) - img_qkv = img_qkv.view(B, L, 3, H, D) - img_q = img_qkv[:, :, 0, :, :].permute(2, 0, 1, 3) - img_k = img_qkv[:, :, 1, :, :].permute(2, 0, 1, 3) - img_v = img_qkv[:, :, 2, :, :].permute(2, 0, 1, 3) + img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 1, 3) del img_qkv img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) @@ -212,10 +208,7 @@ class DoubleStreamBlock(nn.Module): # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) B, L, _ = txt_qkv.shape - txt_qkv = txt_qkv.view(B, L, 3, H, D) - txt_q = txt_qkv[:, :, 0, :, :].permute(2, 0, 1, 3) - txt_k = txt_qkv[:, :, 1, :, :].permute(2, 0, 1, 3) - txt_v = txt_qkv[:, :, 2, :, :].permute(2, 0, 1, 3) + txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 1, 3) del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)