diff --git a/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py b/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py index 1af2f132..76823d9b 100644 --- a/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py +++ b/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py @@ -1,41 +1,45 @@ import torch import gradio as gr from modules import scripts +from ldm_patched.contrib.external_freelunch import FreeU_V2 -def Fourier_filter(x, threshold, scale): - x_freq = torch.fft.fftn(x.float(), dim=(-2, -1)) - x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1)) - B, C, H, W = x_freq.shape - mask = torch.ones((B, C, H, W), device=x.device) - crow, ccol = H // 2, W //2 - mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale - x_freq = x_freq * mask - x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1)) - x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real - return x_filtered.to(x.dtype) +opFreeU_V2 = FreeU_V2() -def set_freeu_v2_patch(model, b1, b2, s1, s2): - model_channels = model.model.model_config.unet_config["model_channels"] - scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} - - def output_block_patch(h, hsp, *args, **kwargs): - scale = scale_dict.get(h.shape[1], None) - if scale is not None: - hidden_mean = h.mean(1).unsqueeze(1) - B = hidden_mean.shape[0] - hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) - hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) - hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / \ - (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) - h[:, :h.shape[1] // 2] = h[:, :h.shape[1] // 2] * ((scale[0] - 1) * hidden_mean + 1) - hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) - return h, hsp - - m = model.clone() - m.set_model_output_block_patch(output_block_patch) - return m +# def Fourier_filter(x, threshold, scale): +# x_freq = torch.fft.fftn(x.float(), dim=(-2, -1)) +# x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1)) +# B, C, H, W = x_freq.shape +# mask = torch.ones((B, C, H, W), device=x.device) +# crow, ccol = H // 2, W //2 +# mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale +# x_freq = x_freq * mask +# x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1)) +# x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real +# return x_filtered.to(x.dtype) +# +# +# def set_freeu_v2_patch(model, b1, b2, s1, s2): +# model_channels = model.model.model_config.unet_config["model_channels"] +# scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} +# +# def output_block_patch(h, hsp, *args, **kwargs): +# scale = scale_dict.get(h.shape[1], None) +# if scale is not None: +# hidden_mean = h.mean(1).unsqueeze(1) +# B = hidden_mean.shape[0] +# hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) +# hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) +# hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / \ +# (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) +# h[:, :h.shape[1] // 2] = h[:, :h.shape[1] // 2] * ((scale[0] - 1) * hidden_mean + 1) +# hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) +# return h, hsp +# +# m = model.clone() +# m.set_model_output_block_patch(output_block_patch) +# return m class FreeUForForge(scripts.Script): @@ -64,7 +68,8 @@ class FreeUForForge(scripts.Script): unet = p.sd_model.forge_objects.unet - unet = set_freeu_v2_patch(unet, freeu_b1, freeu_b2, freeu_s1, freeu_s2) + # unet = set_freeu_v2_patch(unet, freeu_b1, freeu_b2, freeu_s1, freeu_s2) + unet = opFreeU_V2.patch(unet, freeu_b1, freeu_b2, freeu_s1, freeu_s2)[0] p.sd_model.forge_objects.unet = unet