mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-09 06:59:48 +00:00
SAG + Kohya HRFix: different method (#2304)
store shape after Kohya HRFix is applied, use it in SAG instead of trying to calculate. AFAICT, calculation can never be 100% correct due to rounding. Original method is tried first. additional stored shape can be used by other extensions, e.g. Forge Couple, to enable compatibility with Kohya HRFix
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts
|
||||
from modules import scripts, shared
|
||||
from modules.ui_components import InputAccordion
|
||||
from backend.misc.image_resize import adaptive_resize
|
||||
|
||||
|
||||
@@ -14,11 +15,16 @@ class PatchModelAddDownscale:
|
||||
sigma = transformer_options["sigmas"][0].item()
|
||||
if sigma <= sigma_start and sigma >= sigma_end:
|
||||
h = adaptive_resize(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
|
||||
|
||||
shared.kohya_shrink_shape = (h.shape[-1], h.shape[-2])
|
||||
shared.kohya_shrink_shape_out = None
|
||||
return h
|
||||
|
||||
def output_block_patch(h, hsp, transformer_options):
|
||||
if h.shape[2] != hsp.shape[2]:
|
||||
h = adaptive_resize(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
|
||||
|
||||
shared.kohya_shrink_shape_out = (h.shape[-1], h.shape[-2])
|
||||
return h, hsp
|
||||
|
||||
m = model.clone()
|
||||
@@ -44,15 +50,28 @@ class KohyaHRFixForForge(scripts.Script):
|
||||
|
||||
def ui(self, *args, **kwargs):
|
||||
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||
with gr.Accordion(open=False, label=self.title()):
|
||||
enabled = gr.Checkbox(label='Enabled', value=False)
|
||||
block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1)
|
||||
downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001)
|
||||
start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001)
|
||||
end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001)
|
||||
with InputAccordion(False, label=self.title()) as enabled:
|
||||
with gr.Row():
|
||||
block_number = gr.Slider(label='Block Number', value=3, minimum=1, maximum=32, step=1)
|
||||
downscale_factor = gr.Slider(label='Downscale Factor', value=2.0, minimum=0.1, maximum=9.0, step=0.001)
|
||||
with gr.Row():
|
||||
start_percent = gr.Slider(label='Start Percent', value=0.0, minimum=0.0, maximum=1.0, step=0.001)
|
||||
end_percent = gr.Slider(label='End Percent', value=0.35, minimum=0.0, maximum=1.0, step=0.001)
|
||||
downscale_after_skip = gr.Checkbox(label='Downscale After Skip', value=True)
|
||||
downscale_method = gr.Radio(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0])
|
||||
upscale_method = gr.Radio(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0])
|
||||
with gr.Row():
|
||||
downscale_method = gr.Dropdown(label='Downscale Method', choices=upscale_methods, value=upscale_methods[0])
|
||||
upscale_method = gr.Dropdown(label='Upscale Method', choices=upscale_methods, value=upscale_methods[0])
|
||||
|
||||
self.infotext_fields = [
|
||||
(enabled, lambda d: d.get("kohya_hrfix_enabled", False)),
|
||||
(block_number, "kohya_hrfix_block_number"),
|
||||
(downscale_factor, "kohya_hrfix_downscale_factor"),
|
||||
(start_percent, "kohya_hrfix_start_percent"),
|
||||
(end_percent, "kohya_hrfix_end_percent"),
|
||||
(downscale_after_skip, "kohya_hrfix_downscale_after_skip"),
|
||||
(downscale_method, "kohya_hrfix_downscale_method"),
|
||||
(upscale_method, "kohya_hrfix_upscale_method"),
|
||||
]
|
||||
|
||||
return enabled, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method
|
||||
|
||||
|
||||
@@ -62,20 +62,18 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
|
||||
# original method: works for all normal inputs that *do not* have Kohya HRFix scaling; typically fails with scaling
|
||||
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
|
||||
h = math.ceil(lh / ratio)
|
||||
w = math.ceil(lw / ratio)
|
||||
|
||||
|
||||
if h * w != mask.size(1):
|
||||
# this new calculation, to work with Kohya HRFix, sometimes incorrectly rounds up w or h
|
||||
# so we only use it if the original method failed to calculate correct w, h
|
||||
f = float(lh) / float(lw)
|
||||
fh = f ** 0.5
|
||||
fw = (1/f) ** 0.5
|
||||
S = mask.size(1) ** 0.5
|
||||
w = int(0.5 + S * fw)
|
||||
h = int(0.5 + S * fh)
|
||||
|
||||
kohya_shrink_shape = getattr(shared, 'kohya_shrink_shape', None)
|
||||
if kohya_shrink_shape:
|
||||
w = kohya_shrink_shape[0] # works with all block numbers for kohya hrfix
|
||||
h = kohya_shrink_shape[1]
|
||||
|
||||
# Reshape
|
||||
mask = (
|
||||
mask.reshape(b, h, w)
|
||||
|
||||
Reference in New Issue
Block a user