From 79adfa8998c7b6c448d188cc469eda3540f17b53 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Thu, 8 Aug 2024 16:09:36 -0700
Subject: [PATCH] optimization part 2

https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15804
---
 backend/nn/flux.py | 59 ++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 49 insertions(+), 10 deletions(-)

diff --git a/backend/nn/flux.py b/backend/nn/flux.py
index 256684cb..712d6d5d 100644
--- a/backend/nn/flux.py
+++ b/backend/nn/flux.py
@@ -1,4 +1,4 @@
-# Single File Implementation of Flux with aggressive optimizations, copyright Forge 2024
+# Single File Implementation of Flux with aggressive optimizations, Copyright Forge 2024
 # If used outside Forge, only non-commercial use is allowed.
 # See also https://github.com/black-forest-labs/flux
 
@@ -20,9 +20,19 @@ def attention(q, k, v, pe):
 def rope(pos, dim, theta):
     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
     omega = 1.0 / (theta ** scale)
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+
+    # out = torch.einsum("...n,d->...nd", pos, omega)
+    out = pos.unsqueeze(-1) * omega.unsqueeze(0)
+
+    cos_out = torch.cos(out)
+    sin_out = torch.sin(out)
+    out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+    del cos_out, sin_out
+
+    # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+    b, n, d, _ = out.shape
+    out = out.view(b, n, d, 2, 2)
+
     return out.float()
 
 
@@ -115,11 +125,19 @@ class SelfAttention(nn.Module):
 
     def forward(self, x, pe):
         qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+
+        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = qkv.shape
+        qkv = qkv.view(B, L, 3, self.num_heads, -1)  # Split into Q, K, V
+        qkv = qkv.permute(2, 0, 3, 1, 4)  # Rearrange to (K B H L D)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # Separate Q, K, V
         del qkv
+
         q, k = self.norm(q, k, v)
+
         x = attention(q, k, v, pe=pe)
         del q, k, v
+
         x = self.proj(x)
         return x
 
@@ -173,8 +191,17 @@ class DoubleStreamBlock(nn.Module):
         del img_mod1_shift, img_mod1_scale
         img_qkv = self.img_attn.qkv(img_modulated)
         del img_modulated
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+
+        # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = img_qkv.shape
+        H = self.num_heads
+        D = img_qkv.shape[-1] // (3 * H)
+        img_qkv = img_qkv.view(B, L, 3, H, D)
+        img_q = img_qkv[:, :, 0, :, :].permute(0, 2, 1, 3)
+        img_k = img_qkv[:, :, 1, :, :].permute(0, 2, 1, 3)
+        img_v = img_qkv[:, :, 2, :, :].permute(0, 2, 1, 3)
         del img_qkv
+
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         txt_modulated = self.txt_norm1(txt)
@@ -182,8 +209,15 @@ class DoubleStreamBlock(nn.Module):
         del txt_mod1_shift, txt_mod1_scale
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         del txt_modulated
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+
+        # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = txt_qkv.shape
+        txt_qkv = txt_qkv.view(B, L, 3, H, D)
+        txt_q = txt_qkv[:, :, 0, :, :].permute(0, 2, 1, 3)
+        txt_k = txt_qkv[:, :, 1, :, :].permute(0, 2, 1, 3)
+        txt_v = txt_qkv[:, :, 2, :, :].permute(0, 2, 1, 3)
         del txt_qkv
+
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         q = torch.cat((txt_q, img_q), dim=2)
@@ -194,7 +228,7 @@ class DoubleStreamBlock(nn.Module):
 
         del txt_v, img_v
         attn = attention(q, k, v, pe=pe)
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+        txt_attn, img_attn = attn[:, :txt.shape[1], :], attn[:, txt.shape[1]:, :]
         del attn
 
         img = img + img_mod1_gate * self.img_attn.proj(img_attn)
@@ -233,12 +267,17 @@ class SingleStreamBlock(nn.Module):
         del mod_shift, mod_scale
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
         del x_mod
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+
+        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, L, D)
+        q, k, v = qkv[0], qkv[1], qkv[2]
         del qkv
+
         q, k = self.norm(q, k, v)
         attn = attention(q, k, v, pe=pe)
         del q, k, v, pe
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), dim=2))
         del attn, mlp
 
         return x + mod_gate * output
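
A note on the rope() change above: the patch replaces torch.einsum and einops.rearrange with a broadcasted multiply and a plain view. The following is a minimal standalone sketch (illustrative tensor sizes only, not values taken from the patch) checking that both paths produce the same result:

import torch
from einops import rearrange

# Illustrative sizes; b, n, dim, theta are placeholders.
b, n, dim, theta = 1, 64, 16, 10000
pos = torch.arange(n, dtype=torch.float64).expand(b, n)
scale = torch.arange(0, dim, 2, dtype=torch.float64) / dim
omega = 1.0 / (theta ** scale)

# einsum path (removed) vs. broadcasted multiply (added)
ref = torch.einsum("...n,d->...nd", pos, omega)
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
assert torch.allclose(ref, out)

# rearrange path (removed) vs. plain view (added) for the 2x2 rotation layout
stacked = torch.stack([out.cos(), -out.sin(), out.sin(), out.cos()], dim=-1)
ref2 = rearrange(stacked, "b n d (i j) -> b n d i j", i=2, j=2)
b_, n_, d_, _ = stacked.shape
assert torch.equal(ref2, stacked.view(b_, n_, d_, 2, 2))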
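Likewise for the qkv unpacking in SelfAttention, DoubleStreamBlock and SingleStreamBlock: rearrange("B L (K H D) -> K B H L D") is swapped for a view plus permute. A small sketch, again with made-up sizes, confirming both forms yield identical q, k, v tensors:

import torch
from einops import rearrange

# Illustrative sizes; B, L, H, D are placeholders, not values from the model.
B, L, H, D = 2, 16, 24, 128
qkv = torch.randn(B, L, 3 * H * D)

# einops path (removed by the commit)
q0, k0, v0 = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)

# plain-PyTorch path (added by the commit): view to (B, L, K, H, D), then move K to the front
q1, k1, v1 = qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
assert torch.equal(q0, q1) and torch.equal(k0, k1) and torch.equal(v0, v1)

# Per-slice form, as in DoubleStreamBlock: each K-slice is (B, L, H, D),
# so permute(0, 2, 1, 3) gives the (B, H, L, D) layout expected by attention().
q2 = qkv.view(B, L, 3, H, D)[:, :, 0].permute(0, 2, 1, 3)
assert torch.equal(q0, q2)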