diff --git a/extensions-builtin/sd_forge_kohya_hrfix/scripts/kohya_hrfix.py b/extensions-builtin/sd_forge_kohya_hrfix/scripts/kohya_hrfix.py index e872c0c7..21aa248b 100644 --- a/extensions-builtin/sd_forge_kohya_hrfix/scripts/kohya_hrfix.py +++ b/extensions-builtin/sd_forge_kohya_hrfix/scripts/kohya_hrfix.py @@ -1,6 +1,7 @@ import gradio as gr -from modules import scripts +from modules import scripts, shared +from modules.ui_components import InputAccordion from backend.misc.image_resize import adaptive_resize @@ -14,11 +15,16 @@ class PatchModelAddDownscale: sigma = transformer_options["sigmas"][0].item() if sigma <= sigma_start and sigma >= sigma_end: h = adaptive_resize(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled") + + shared.kohya_shrink_shape = (h.shape[-1], h.shape[-2]) + shared.kohya_shrink_shape_out = None return h def output_block_patch(h, hsp, transformer_options): if h.shape[2] != hsp.shape[2]: h = adaptive_resize(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled") + + shared.kohya_shrink_shape_out = (h.shape[-1], h.shape[-2]) return h, hsp m = model.clone() @@ -44,15 +50,28 @@ class KohyaHRFixForForge(scripts.Script): def ui(self, *args, **kwargs): upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] - with gr.Accordion(open=False, label=self.title()): - enabled = gr.Checkbox(label='Enabled', value=False) - block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1) - downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001) - start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001) - end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001) + with InputAccordion(False, label=self.title()) as enabled: + with gr.Row(): + block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1) + downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001) + with gr.Row(): + start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001) + end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001) downscale_after_skip = gr.Checkbox(label='Downscale After Skip', value=True) - downscale_method = gr.Radio(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0]) - upscale_method = gr.Radio(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0]) + with gr.Row(): + downscale_method = gr.Dropdown(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0]) + upscale_method = gr.Dropdown(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0]) + + self.infotext_fields = [ + (enabled, lambda d: d.get("kohya_hrfix_enabled", False)), + (block_number, "kohya_hrfix_block_number"), + (downscale_factor, "kohya_hrfix_downscale_factor"), + (start_percent, "kohya_hrfix_start_percent"), + (end_percent, "kohya_hrfix_end_percent"), + (downscale_after_skip, "kohya_hrfix_downscale_after_skip"), + (downscale_method, "kohya_hrfix_downscale_method"), + (upscale_method, "kohya_hrfix_upscale_method"), + ] return enabled, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method diff --git a/extensions-builtin/sd_forge_sag/scripts/forge_sag.py b/extensions-builtin/sd_forge_sag/scripts/forge_sag.py index 8491ca3f..e0786d63 100644 --- a/extensions-builtin/sd_forge_sag/scripts/forge_sag.py +++ b/extensions-builtin/sd_forge_sag/scripts/forge_sag.py @@ -62,20 +62,18 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): attn = attn.reshape(b, -1, hw1, hw2) # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold + + # original method: works for all normal inputs that *do not* have Kohya HRFix scaling; typically fails with scaling ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length() h = math.ceil(lh / ratio) w = math.ceil(lw / ratio) - + if h * w != mask.size(1): - # this new calculation, to work with Kohya HRFix, sometimes incorrectly rounds up w or h - # so we only use it if the original method failed to calculate correct w, h - f = float(lh) / float(lw) - fh = f ** 0.5 - fw = (1/f) ** 0.5 - S = mask.size(1) ** 0.5 - w = int(0.5 + S * fw) - h = int(0.5 + S * fh) - + kohya_shrink_shape = getattr(shared, 'kohya_shrink_shape', None) + if kohya_shrink_shape: + w = kohya_shrink_shape[0] # works with all block numbers for kohya hrfix + h = kohya_shrink_shape[1] + # Reshape mask = ( mask.reshape(b, h, w)