mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 17:09:49 +00:00
optimization part 2
https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15804
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Single File Implementation of Flux with aggressive optimizations, copyright Forge 2024
|
||||
# Single File Implementation of Flux with aggressive optimizations, Copyright Forge 2024
|
||||
# If used outside Forge, only non-commercial use is allowed.
|
||||
# See also https://github.com/black-forest-labs/flux
|
||||
|
||||
@@ -20,9 +20,19 @@ def attention(q, k, v, pe):
|
||||
def rope(pos, dim, theta):
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta ** scale)
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
|
||||
# out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
|
||||
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
del cos_out, sin_out
|
||||
|
||||
# out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||
b, n, d, _ = out.shape
|
||||
out = out.view(b, n, d, 2, 2)
|
||||
|
||||
return out.float()
|
||||
|
||||
|
||||
@@ -115,11 +125,19 @@ class SelfAttention(nn.Module):
|
||||
|
||||
def forward(self, x, pe):
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
B, L, _ = qkv.shape
|
||||
qkv = qkv.view(B, L, 3, self.num_heads, -1) # Split into Q, K, V
|
||||
qkv = qkv.permute(2, 0, 3, 1, 4) # Rearrange to (K B H L D)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # Separate Q, K, V
|
||||
del qkv
|
||||
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
x = attention(q, k, v, pe=pe)
|
||||
del q, k, v
|
||||
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
@@ -173,8 +191,17 @@ class DoubleStreamBlock(nn.Module):
|
||||
del img_mod1_shift, img_mod1_scale
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
del img_modulated
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
# txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
B, L, _ = img_qkv.shape
|
||||
H = self.num_heads
|
||||
D = img_qkv.shape[-1] // (3 * H)
|
||||
img_qkv = img_qkv.view(B, L, 3, H, D)
|
||||
img_q = img_qkv[:, :, 0, :, :].permute(2, 0, 1, 3)
|
||||
img_k = img_qkv[:, :, 1, :, :].permute(2, 0, 1, 3)
|
||||
img_v = img_qkv[:, :, 2, :, :].permute(2, 0, 1, 3)
|
||||
del img_qkv
|
||||
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
@@ -182,8 +209,15 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_mod1_shift, txt_mod1_scale
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
del txt_modulated
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
B, L, _ = txt_qkv.shape
|
||||
txt_qkv = txt_qkv.view(B, L, 3, H, D)
|
||||
txt_q = txt_qkv[:, :, 0, :, :].permute(2, 0, 1, 3)
|
||||
txt_k = txt_qkv[:, :, 1, :, :].permute(2, 0, 1, 3)
|
||||
txt_v = txt_qkv[:, :, 2, :, :].permute(2, 0, 1, 3)
|
||||
del txt_qkv
|
||||
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
@@ -194,7 +228,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_v, img_v
|
||||
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
txt_attn, img_attn = attn[:, :, :txt.shape[1], :], attn[:, :, txt.shape[1]:, :]
|
||||
del attn
|
||||
|
||||
img = img + img_mod1_gate * self.img_attn.proj(img_attn)
|
||||
@@ -233,12 +267,17 @@ class SingleStreamBlock(nn.Module):
|
||||
del mod_shift, mod_scale
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
del x_mod
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
# q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
|
||||
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, L, D)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
del qkv
|
||||
|
||||
q, k = self.norm(q, k, v)
|
||||
attn = attention(q, k, v, pe=pe)
|
||||
del q, k, v, pe
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
|
||||
del attn, mlp
|
||||
return x + mod_gate * output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user