add fp16_fix

layerdiffusion
2024-08-14 17:10:03 -07:00
parent aadc0f04c4
commit b09c24ef51
3 changed files with 21 additions and 2 deletions


@@ -9,6 +9,7 @@ import torch
 from torch import nn
 from einops import rearrange, repeat
 from backend.attention import attention_function
+from backend.utils import fp16_fix
 
 
 def attention(q, k, v, pe):
@@ -242,6 +243,8 @@ class DoubleStreamBlock(nn.Module):
         txt = txt + txt_mod2_gate * self.txt_mlp((1 + txt_mod2_scale) * self.txt_norm2(txt) + txt_mod2_shift)
         del txt_mod2_gate, txt_mod2_scale, txt_mod2_shift
 
+        txt = fp16_fix(txt)
+
         return img, txt
@@ -279,7 +282,13 @@ class SingleStreamBlock(nn.Module):
         del q, k, v, pe
 
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
         del attn, mlp
 
-        return x + mod_gate * output
+        x = x + mod_gate * output
+        del mod_gate, output
+
+        x = fp16_fix(x)
+
+        return x
 
 
 class LastLayer(nn.Module):
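
Note: the helper imported from backend.utils is not part of this diff. As a rough illustration only (the clamp bound below is an assumption, not taken from this commit), such a guard typically clips float16 activations to a range safely below the fp16 maximum of 65504, so the residual additions in DoubleStreamBlock and SingleStreamBlock cannot overflow to inf/NaN:

import torch

def fp16_fix(x: torch.Tensor) -> torch.Tensor:
    # Clamp float16 activations to a bound well inside the fp16 maximum (65504).
    # 16384.0 is an illustrative choice; the real backend.utils.fp16_fix may
    # use a different threshold or strategy.
    if x.dtype == torch.float16:
        return x.clamp(-16384.0, 16384.0)
    return x

With a guard of this kind, txt = fp16_fix(txt) and x = fp16_fix(x) are no-ops in fp32/bf16 and only clip when the model actually runs in float16.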