a bit more cuda profile

This commit is contained in:
layerdiffusion
2024-08-08 20:42:50 -07:00
parent 6f254f3599
commit 6f7aea20b3

View File

@@ -180,7 +180,6 @@ class DoubleStreamBlock(nn.Module):
def forward(self, img, txt, vec, pe):
img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = self.img_mod(vec)
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec)
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift
@@ -197,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec)
del vec
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift
del txt_mod1_shift, txt_mod1_scale
@@ -218,6 +220,7 @@ class DoubleStreamBlock(nn.Module):
del txt_v, img_v
attn = attention(q, k, v, pe=pe)
del pe
txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:]
del attn