From 6f7aea20b3de73a3c60b4df05663d31eb73a4310 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Thu, 8 Aug 2024 20:42:50 -0700 Subject: [PATCH] a bit more cuda profile --- backend/nn/flux.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/backend/nn/flux.py b/backend/nn/flux.py index 0b80ec3a..d12bd30c 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -180,7 +180,6 @@ class DoubleStreamBlock(nn.Module): def forward(self, img, txt, vec, pe): img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = self.img_mod(vec) - txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec) img_modulated = self.img_norm1(img) img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift @@ -197,6 +196,9 @@ class DoubleStreamBlock(nn.Module): img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec) + del vec + txt_modulated = self.txt_norm1(txt) txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift del txt_mod1_shift, txt_mod1_scale @@ -218,6 +220,7 @@ class DoubleStreamBlock(nn.Module): del txt_v, img_v attn = attention(q, k, v, pe=pe) + del pe txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:] del attn