From 8d336c3247b03d40f5179d42c8a8f0fe285fd7a9 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Thu, 8 Aug 2024 16:13:39 -0700 Subject: [PATCH] fix comments --- backend/nn/flux.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/nn/flux.py b/backend/nn/flux.py index 712d6d5d..1a54992c 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -192,7 +192,7 @@ class DoubleStreamBlock(nn.Module): img_qkv = self.img_attn.qkv(img_modulated) del img_modulated - # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) B, L, _ = img_qkv.shape H = self.num_heads D = img_qkv.shape[-1] // (3 * H) @@ -210,7 +210,7 @@ class DoubleStreamBlock(nn.Module): txt_qkv = self.txt_attn.qkv(txt_modulated) del txt_modulated - # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) + # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) B, L, _ = txt_qkv.shape txt_qkv = txt_qkv.view(B, L, 3, H, D) txt_q = txt_qkv[:, :, 0, :, :].permute(2, 0, 1, 3)