SAG + Kohya HRFix: different method (#2304)

store shape after Kohya HRFix is applied, use it in SAG instead of trying to calculate. AFAICT, calculation can never be 100% correct due to rounding. Original method is tried first.
additional stored shape can be used by other extensions, e.g. Forge Couple, to enable compatibility with Kohya HRFix
This commit is contained in:
DenOfEquity
2024-12-06 21:36:58 +00:00
committed by GitHub
parent 9fbba69297
commit 85a7db3c0f
2 changed files with 36 additions and 19 deletions

View File

@@ -1,6 +1,7 @@
import gradio as gr
from modules import scripts
from modules import scripts, shared
from modules.ui_components import InputAccordion
from backend.misc.image_resize import adaptive_resize
@@ -14,11 +15,16 @@ class PatchModelAddDownscale:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = adaptive_resize(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
shared.kohya_shrink_shape = (h.shape[-1], h.shape[-2])
shared.kohya_shrink_shape_out = None
return h
def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = adaptive_resize(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
shared.kohya_shrink_shape_out = (h.shape[-1], h.shape[-2])
return h, hsp
m = model.clone()
@@ -44,15 +50,28 @@ class KohyaHRFixForForge(scripts.Script):
def ui(self, *args, **kwargs):
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
with gr.Accordion(open=False, label=self.title()):
enabled = gr.Checkbox(label='Enabled', value=False)
block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1)
downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001)
start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001)
end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001)
with InputAccordion(False, label=self.title()) as enabled:
with gr.Row():
block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1)
downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001)
with gr.Row():
start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001)
end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001)
downscale_after_skip = gr.Checkbox(label='Downscale After Skip', value=True)
downscale_method = gr.Radio(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0])
upscale_method = gr.Radio(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0])
with gr.Row():
downscale_method = gr.Dropdown(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0])
upscale_method = gr.Dropdown(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0])
self.infotext_fields = [
(enabled, lambda d: d.get("kohya_hrfix_enabled", False)),
(block_number, "kohya_hrfix_block_number"),
(downscale_factor, "kohya_hrfix_downscale_factor"),
(start_percent, "kohya_hrfix_start_percent"),
(end_percent, "kohya_hrfix_end_percent"),
(downscale_after_skip, "kohya_hrfix_downscale_after_skip"),
(downscale_method, "kohya_hrfix_downscale_method"),
(upscale_method, "kohya_hrfix_upscale_method"),
]
return enabled, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method

View File

@@ -62,20 +62,18 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
# original method: works for all normal inputs that *do not* have Kohya HRFix scaling; typically fails with scaling
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
h = math.ceil(lh / ratio)
w = math.ceil(lw / ratio)
if h * w != mask.size(1):
# this new calculation, to work with Kohya HRFix, sometimes incorrectly rounds up w or h
# so we only use it if the original method failed to calculate correct w, h
f = float(lh) / float(lw)
fh = f ** 0.5
fw = (1/f) ** 0.5
S = mask.size(1) ** 0.5
w = int(0.5 + S * fw)
h = int(0.5 + S * fh)
kohya_shrink_shape = getattr(shared, 'kohya_shrink_shape', None)
if kohya_shrink_shape:
w = kohya_shrink_shape[0] # works with all block numbers for kohya hrfix
h = kohya_shrink_shape[1]
# Reshape
mask = (
mask.reshape(b, h, w)