From 5c79b0d96d1f1cc90b62ea4347efc2130e4696e0 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 16:09:06 -0800 Subject: [PATCH] Update forge_freeu.py --- .../sd_forge_freeu/scripts/forge_freeu.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py b/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py index 55772999..18b01ab0 100644 --- a/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py +++ b/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py @@ -1,7 +1,43 @@ +import torch import gradio as gr from modules import scripts +def Fourier_filter(x, threshold, scale): + x_freq = torch.fft.fftn(x.float(), dim=(-2, -1)) + x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1)) + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + crow, ccol = H // 2, W //2 + mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale + x_freq = x_freq * mask + x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real + return x_filtered.to(x.dtype) + + +def patch(model, b1, b2, s1, s2): + model_channels = model.model.model_config.unet_config["model_channels"] + scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} + + def output_block_patch(h, hsp, transformer_options): + scale = scale_dict.get(h.shape[1], None) + if scale is not None: + hidden_mean = h.mean(1).unsqueeze(1) + B = hidden_mean.shape[0] + hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / \ + (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) + h[:, :h.shape[1] // 2] = h[:, :h.shape[1] // 2] * ((scale[0] - 1) * hidden_mean + 1) + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + return h, hsp + + m = model.clone() + m.set_model_output_block_patch(output_block_patch) + return m + + class FreeUForForge(scripts.Script): def title(self): return "FreeU Integrated" @@ -22,7 +58,7 @@ class FreeUForForge(scripts.Script): def process_batch(self, p, *script_args, **kwargs): freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2 = script_args - + if not freeu_enabled: return