diff --git a/backend/nn/flux.py b/backend/nn/flux.py index 0b80ec3a..d12bd30c 100644 --- a/backend/nn/flux.py +++ b/backend/nn/flux.py @@ -180,7 +180,6 @@ class DoubleStreamBlock(nn.Module): def forward(self, img, txt, vec, pe): img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = self.img_mod(vec) - txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec) img_modulated = self.img_norm1(img) img_modulated = (1 + img_mod1_scale) * img_modulated + img_mod1_shift @@ -197,6 +196,9 @@ class DoubleStreamBlock(nn.Module): img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = self.txt_mod(vec) + del vec + txt_modulated = self.txt_norm1(txt) txt_modulated = (1 + txt_mod1_scale) * txt_modulated + txt_mod1_shift del txt_mod1_shift, txt_mod1_scale @@ -218,6 +220,7 @@ class DoubleStreamBlock(nn.Module): del txt_v, img_v attn = attention(q, k, v, pe=pe) + del pe txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:] del attn