mirror of
https://github.com/altoiddealer/--sd-webui-ar-plusplus.git
synced 2026-01-26 19:19:55 +00:00
Most credit goes to [LEv145](https://github.com/LEv145) for solving the inheritance issue in sd-webui-forge (due to their implementation of Gradio 4)
471 lines
18 KiB
Python
471 lines
18 KiB
Python
import typing as t
|
||
import contextlib
|
||
from pathlib import Path
|
||
|
||
import gradio as gr
|
||
|
||
import modules.scripts as scripts
|
||
from modules.ui_components import ToolButton
|
||
|
||
from fractions import Fraction
|
||
from math import sqrt
|
||
|
||
BASE_PATH = scripts.basedir()
|
||
|
||
class ButtonState():
|
||
def __init__(self):
|
||
self.is_locked = False
|
||
self.switched = False
|
||
self.alt_mode = False
|
||
def toggle_lock(self):
|
||
self.is_locked = not self.is_locked
|
||
def toggle_switch(self):
|
||
self.switched = not self.switched
|
||
def toggle_mode(self):
|
||
self.alt_mode = not self.alt_mode
|
||
|
||
txt2img_state = ButtonState()
|
||
img2img_state = ButtonState()
|
||
|
||
# Helper functions for calculating new width/height values
|
||
def round_to_precision(val, prec):
|
||
return round(val / prec) * prec
|
||
|
||
def res_to_model_fit(avg, w, h, prec):
|
||
mp = w * h
|
||
mp_target = avg * avg
|
||
scale = sqrt(mp_target / mp)
|
||
w = int(round_to_precision(w * scale, prec))
|
||
h = int(round_to_precision(h * scale, prec))
|
||
return w, h
|
||
|
||
def calc_width(n, d, w, h, prec):
|
||
ar = round((n / d), 2) # Convert AR parts to fraction
|
||
if ar > 1.0:
|
||
h = w / ar
|
||
elif ar < 1.0:
|
||
w = h * ar
|
||
else:
|
||
new_value = max([w, h])
|
||
w, h = new_value, new_value
|
||
w = int(round_to_precision((w + prec / 2), prec))
|
||
h = int(round_to_precision((h + prec / 2), prec))
|
||
return w, h
|
||
|
||
def calc_height(n, d, w, h, prec):
|
||
ar = round((n / d), 2) # Convert AR parts to fraction
|
||
if ar > 1.0:
|
||
w = h * ar
|
||
elif ar < 1.0:
|
||
h = w / ar
|
||
else:
|
||
new_value = min([w, h])
|
||
w, h = new_value, new_value
|
||
w = int(round_to_precision((w + prec / 2), prec))
|
||
h = int(round_to_precision((h + prec / 2), prec))
|
||
return w, h
|
||
|
||
def dims_from_ar(avg, n, d, prec):
|
||
doubleavg = avg * 2
|
||
ar_sum = n+d
|
||
# calculate width and height by factoring average with aspect ratio
|
||
w = round((n / ar_sum) * doubleavg)
|
||
h = round((d / ar_sum) * doubleavg)
|
||
# Round to correct megapixel precision
|
||
w, h = res_to_model_fit(avg, w, h, prec)
|
||
return w, h
|
||
|
||
def avg_from_dims(w, h):
|
||
avg = (w + h) // 2
|
||
if (w + h) % 2 != 0:
|
||
avg += 1
|
||
return avg
|
||
|
||
## Aspect Ratio buttons
|
||
def create_ar_button_function(ar:str, is_img2img:bool):
|
||
def wrapper(avg, prec, w=512, h=512):
|
||
# Determine the state based on whether it's img2img or txt2img
|
||
state = img2img_state if is_img2img else txt2img_state
|
||
|
||
n, d = map(Fraction, ar.split(':')) # Split numerator and denominator
|
||
|
||
if not state.is_locked:
|
||
avg = avg_from_dims(w, h) # Get average of current width/height values
|
||
|
||
if not state.alt_mode: # True = offset, False = One dimension
|
||
w, h = dims_from_ar(avg, n, d, prec) # Calculate new w + h from avg, AR, and precision
|
||
if state.switched: # Switch results if switch mode is active
|
||
w, h = h, w
|
||
else: # Calculate w or h from input, AR, and precision
|
||
if state.switched: # Switch results if switch mode is active
|
||
w, h = calc_width(n, d, w, h, prec) # Modify width
|
||
else:
|
||
w, h = calc_height(n, d, w, h, prec) # Modify height
|
||
|
||
return avg, w, h
|
||
|
||
return wrapper
|
||
|
||
def create_ar_buttons(
|
||
lst: t.Iterable[str],
|
||
is_img2img: bool,
|
||
) -> t.Tuple[t.List[ToolButton], t.Dict[ToolButton, t.Callable]]:
|
||
buttons = []
|
||
functions = {}
|
||
|
||
for ar in lst:
|
||
button = ToolButton(ar, render=False)
|
||
function = create_ar_button_function(ar, is_img2img)
|
||
buttons.append(button)
|
||
functions[button] = function
|
||
|
||
return buttons, functions
|
||
|
||
## Static Resolution buttons
|
||
def create_res_button_function(w:int, h:int, is_img2img:bool):
|
||
def wrapper(avg):
|
||
state = img2img_state if is_img2img else txt2img_state
|
||
if not state.is_locked:
|
||
avg = avg_from_dims(w, h)
|
||
return avg, w, h
|
||
return wrapper
|
||
|
||
def create_res_buttons(
|
||
lst: t.Iterable[t.Tuple[t.List[int], str]],
|
||
is_img2img: bool,
|
||
) -> t.Tuple[t.List[ToolButton], t.Dict[ToolButton, t.Callable]]:
|
||
buttons = []
|
||
functions = {}
|
||
|
||
for resolution, label in lst:
|
||
button = ToolButton(label)
|
||
w, h = resolution
|
||
function = create_res_button_function(w, h, is_img2img)
|
||
buttons.append(button)
|
||
functions[button] = function
|
||
|
||
return buttons, functions
|
||
|
||
# Get values for Aspect Ratios from file
|
||
def parse_aspect_ratios_file(filename):
|
||
values, flipvals, comments = [], [], []
|
||
file = Path(BASE_PATH, filename)
|
||
|
||
if not file.exists():
|
||
return values, comments, flipvals
|
||
|
||
with open(file, "r", encoding="utf-8") as f:
|
||
lines = f.readlines()
|
||
|
||
if not lines:
|
||
return values, comments, flipvals
|
||
|
||
for line in lines:
|
||
if line.startswith("#"):
|
||
continue
|
||
|
||
value = line.strip()
|
||
|
||
comment = ""
|
||
if "#" in value:
|
||
value, comment = value.split("#")
|
||
value = value.strip()
|
||
values.append(value)
|
||
comments.append(comment)
|
||
|
||
comp1, comp2 = value.split(':')
|
||
flipval = f"{comp2}:{comp1}"
|
||
flipvals.append(flipval)
|
||
|
||
return values, comments, flipvals
|
||
|
||
# Get values for Static Resolutions from file
|
||
def parse_resolutions_file(filename):
|
||
labels, values, comments = [], [], []
|
||
file = Path(BASE_PATH, filename)
|
||
|
||
if not file.exists():
|
||
return labels, values, comments
|
||
|
||
with open(file, "r", encoding="utf-8") as f:
|
||
lines = f.readlines()
|
||
|
||
if not lines:
|
||
return labels, values, comments
|
||
|
||
for line in lines:
|
||
if line.startswith("#"):
|
||
continue
|
||
|
||
label, width, height = line.strip().split(",")
|
||
comment = ""
|
||
if "#" in height:
|
||
height, comment = height.split("#")
|
||
|
||
resolution = (width, height)
|
||
|
||
labels.append(label)
|
||
values.append(resolution)
|
||
comments.append(comment)
|
||
|
||
return labels, values, comments
|
||
|
||
def write_aspect_ratios_file(filename):
|
||
aspect_ratios = [
|
||
"1:1 # Square\n",
|
||
"4:3 # Television Photography\n",
|
||
"3:2 # Photography\n",
|
||
"8:5 # Widescreen Displays\n",
|
||
"16:9 # Widescreen Television\n",
|
||
"21:9 # Ultrawide Cinematography"
|
||
]
|
||
with open(filename, "w", encoding="utf-8") as f:
|
||
f.writelines(aspect_ratios)
|
||
|
||
def write_resolutions_file(filename):
|
||
resolutions = [
|
||
"512, 512, 512 # 512x512\n",
|
||
"768, 768, 768 # 768x768\n",
|
||
"1024, 1024, 1024 # 1024x1024\n",
|
||
"1280, 1280, 1280 # 1280x1280\n",
|
||
"1536, 1536, 1536 # 1536x1536\n",
|
||
"2048, 2048, 2048 # 2048x2048",
|
||
]
|
||
with open(filename, "w", encoding="utf-8") as f:
|
||
f.writelines(resolutions)
|
||
|
||
def write_js_titles_file(button_titles):
|
||
filename = Path(BASE_PATH, "javascript", "button_titles.js")
|
||
content = ["// Do not put custom titles here. This file is overwritten each time the WebUI is started.\n"]
|
||
content.append("arsp__ar_button_titles = {\n")
|
||
counter = 0
|
||
while counter < len(button_titles[0]):
|
||
content.append(f' " {button_titles[0][counter]}" : "{button_titles[1][counter]}",\n')
|
||
counter = counter + 1
|
||
content.append("}")
|
||
|
||
with open(filename, "w", encoding="utf-8") as f:
|
||
f.writelines(content)
|
||
|
||
class AspectRatioScript(scripts.Script):
|
||
def read_aspect_ratios(self):
|
||
ar_file = Path(BASE_PATH, "aspect_ratios.txt")
|
||
if not ar_file.exists():
|
||
write_aspect_ratios_file(ar_file)
|
||
(self.aspect_ratios, self.aspect_ratio_comments, self.flipped_vals) = parse_aspect_ratios_file("aspect_ratios.txt")
|
||
self.ar_buttons_labels = self.aspect_ratios
|
||
|
||
def read_resolutions(self):
|
||
res_file = Path(BASE_PATH, "resolutions.txt")
|
||
if not res_file.exists():
|
||
write_resolutions_file(res_file)
|
||
|
||
self.res_labels, res, self.res_comments = parse_resolutions_file("resolutions.txt")
|
||
self.res = [list(map(int, r)) for r in res]
|
||
|
||
def title(self):
|
||
return "Aspect Ratio picker"
|
||
|
||
def show(self, is_img2img):
|
||
return scripts.AlwaysVisible
|
||
|
||
def ui(self, is_img2img):
|
||
self.LOCK_OPEN_ICON = "\U0001F513" # 🔓
|
||
self.LOCK_CLOSED_ICON = "\U0001F512" # 🔒
|
||
self.LAND_AR_ICON = "\U000025AD" # ▭
|
||
self.PORT_AR_ICON = "\U000025AF" # ▯
|
||
self.INFO_ICON = "\U00002139" # ℹ
|
||
self.INFO_CLOSE_ICON = "\U00002BC5" # ⯅
|
||
self.OFFSET_ICON = "\U00002B83" # ⮃
|
||
self.ONE_DIM_ICON = "\U00002B85" # ⮅
|
||
|
||
# Determine the width and height based on the mode (img2img or txt2img)
|
||
if is_img2img:
|
||
w = self.i2i_w
|
||
h = self.i2i_h
|
||
else:
|
||
w = self.t2i_w
|
||
h = self.t2i_h
|
||
|
||
# Average number box initialize without rendering
|
||
arc_average = gr.Number(label="Current W/H Avg.", value=0, interactive=False, render=False)
|
||
|
||
# Precision input box initialize without rendering
|
||
arc_prec = gr.Number(label="Precision (px)", value=64, minimum=4, maximum=128, step=4, precision=0, render=False)
|
||
|
||
with gr.Accordion(label="Aspect Ratio and Resolution Buttons", open=True):
|
||
|
||
with gr.Column(elem_id=f'arsp__{"img" if is_img2img else "txt"}2img_container_aspect_ratio'):
|
||
|
||
# Get aspect ratios from file
|
||
self.read_aspect_ratios()
|
||
|
||
# Top row
|
||
with gr.Row(elem_id=f'arsp__{"img" if is_img2img else "txt"}2img_row_aspect_ratio'):
|
||
|
||
# Lock button
|
||
arc_lock = ToolButton(value=self.LOCK_OPEN_ICON, visible=True, variant="secondary", elem_id="arsp__arc_lock_button")
|
||
# Lock button click event handling
|
||
def toggle_lock(icon, avg, w=512, h=512):
|
||
icon = self.LOCK_OPEN_ICON if (img2img_state.is_locked if is_img2img else txt2img_state.is_locked) else self.LOCK_CLOSED_ICON
|
||
if is_img2img:
|
||
img2img_state.toggle_lock()
|
||
else:
|
||
txt2img_state.toggle_lock()
|
||
if not avg:
|
||
avg = avg_from_dims(w, h)
|
||
return icon, avg
|
||
if is_img2img:
|
||
lock_w = self.i2i_w
|
||
lock_h = self.i2i_h
|
||
else:
|
||
lock_w = self.t2i_w
|
||
lock_h = self.t2i_h
|
||
# Lock button event listener
|
||
arc_lock.click(toggle_lock,
|
||
inputs = [arc_lock, arc_average, lock_w, lock_h],
|
||
outputs = [arc_lock, arc_average],
|
||
show_progress = 'hidden')
|
||
|
||
# Initialize Aspect Ratio buttons (render=False)
|
||
ar_buttons, ar_functions = create_ar_buttons(self.ar_buttons_labels, is_img2img)
|
||
|
||
# Switch button
|
||
arc_switch = ToolButton(value=self.LAND_AR_ICON, visible=True, variant="secondary", elem_id="arsp__arc_switch_button")
|
||
# Switch button click event handling
|
||
def toggle_switch(*items, **kwargs):
|
||
ar_icons = items[:-1]
|
||
sw_icon = items[-1]
|
||
if ar_icons == tuple(self.aspect_ratios):
|
||
ar_icons = tuple(self.flipped_vals)
|
||
else:
|
||
ar_icons = tuple(self.aspect_ratios)
|
||
sw_icon = self.PORT_AR_ICON if sw_icon == self.LAND_AR_ICON else self.LAND_AR_ICON
|
||
if is_img2img:
|
||
img2img_state.toggle_switch()
|
||
else:
|
||
txt2img_state.toggle_switch()
|
||
return (*ar_icons, sw_icon)
|
||
# Switch button event listener
|
||
arc_switch.click(toggle_switch,
|
||
inputs = ar_buttons+[arc_switch],
|
||
outputs = ar_buttons+[arc_switch])
|
||
|
||
# AR buttons render
|
||
for button in ar_buttons:
|
||
button.render()
|
||
# AR buttons click event handling
|
||
with contextlib.suppress(AttributeError):
|
||
for button in ar_buttons:
|
||
# AR buttons event listener
|
||
button.click(ar_functions[button],
|
||
inputs = [arc_average, arc_prec, w, h],
|
||
outputs = [arc_average, w, h],
|
||
show_progress = 'hidden')
|
||
|
||
# Get static resolutions from file
|
||
self.read_resolutions()
|
||
|
||
# Bottom row
|
||
with gr.Row(elem_id=f'arsp__{"img" if is_img2img else "txt"}2img_row_resolutions'):
|
||
|
||
# Info button to toggle info window
|
||
arc_show_info = ToolButton(value=self.INFO_ICON, visible=True, variant="secondary", elem_id="arsp__arc_show_info_button")
|
||
arc_hide_info = ToolButton(value=self.INFO_CLOSE_ICON, visible=False, variant="secondary", elem_id="arsp__arc_hide_info_button")
|
||
### Click event handling for info window ###
|
||
##### is defined after everything else #####
|
||
|
||
# Mode button
|
||
arc_mode = ToolButton(value=self.OFFSET_ICON, visible=True, variant="secondary", elem_id="arsp__arc_mode_button")
|
||
# Mode button click event handling
|
||
def toggle_mode(icon):
|
||
icon = self.ONE_DIM_ICON if icon == self.OFFSET_ICON else self.OFFSET_ICON
|
||
if is_img2img:
|
||
img2img_state.toggle_mode()
|
||
else:
|
||
txt2img_state.toggle_mode()
|
||
return icon
|
||
# Mode button event listener
|
||
arc_mode.click(toggle_mode,
|
||
inputs = [arc_mode],
|
||
outputs = [arc_mode])
|
||
|
||
# Static res buttons
|
||
buttons, set_res_functions = create_res_buttons(zip(self.res, self.res_labels), is_img2img)
|
||
|
||
# Set up click event listeners for the buttons
|
||
with contextlib.suppress(AttributeError):
|
||
for button in buttons:
|
||
button.click(set_res_functions[button],
|
||
inputs = [arc_average],
|
||
outputs = [arc_average, w, h],
|
||
show_progress = 'hidden')
|
||
|
||
# Write button_titles.js with labels and comments read from aspect ratios and resolutions files
|
||
button_titles = [self.aspect_ratios + self.res_labels]
|
||
button_titles.append(self.aspect_ratio_comments + self.res_comments)
|
||
write_js_titles_file(button_titles)
|
||
|
||
# Information panel
|
||
with gr.Column(visible=False, variant="panel", elem_id="arsp__arc_panel") as arc_panel:
|
||
with gr.Row():
|
||
with gr.Column(scale=2, min_width=100):
|
||
|
||
# Average number box render
|
||
arc_average.render()
|
||
|
||
# Precision input box render
|
||
arc_prec.render()
|
||
|
||
# Information blurb
|
||
gr.Column(scale=1, min_width=10)
|
||
with gr.Column(scale=12):
|
||
arc_title_heading = gr.Markdown(value=
|
||
'''
|
||
### AR and Static Res buttons can be customized in the 'aspect_ratios.txt' and 'resolutions.txt' files
|
||
**Aspect Ratio buttons (Top Row)**:
|
||
(1) Averages the current width/height in the UI; (2) Offsets to the exact aspect ratio; (3) Rounds to precision.
|
||
|
||
**Static Resolution buttons (Bottom Row)**:
|
||
Recommended to use 1:1 values for these, to serve as a start point before clicking AR buttons.
|
||
|
||
**64px Precision is recommended, the same rounding applied for image "bucketing" when model training.**
|
||
'''
|
||
)
|
||
|
||
# Info panel event listeners
|
||
arc_show_info.click(
|
||
lambda: [
|
||
gr.update(visible=True),
|
||
gr.update(visible=False),
|
||
gr.update(visible=True),
|
||
],
|
||
None,
|
||
[
|
||
arc_panel,
|
||
arc_show_info,
|
||
arc_hide_info,
|
||
],
|
||
)
|
||
arc_hide_info.click(
|
||
lambda: [
|
||
gr.update(visible=False),
|
||
gr.update(visible=True),
|
||
gr.update(visible=False),
|
||
],
|
||
None,
|
||
[arc_panel, arc_show_info, arc_hide_info],
|
||
)
|
||
|
||
## Function to update the values in appropriate Width/Height fields
|
||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/7456#issuecomment-1414465888
|
||
def after_component(self, component, **kwargs):
|
||
if kwargs.get("elem_id") == "txt2img_width":
|
||
self.t2i_w = component
|
||
if kwargs.get("elem_id") == "txt2img_height":
|
||
self.t2i_h = component
|
||
if kwargs.get("elem_id") == "img2img_width":
|
||
self.i2i_w = component
|
||
if kwargs.get("elem_id") == "img2img_height":
|
||
self.i2i_h = component
|