Files
--sd-webui-ar-plusplus/scripts/sd-webui-ar-plusplus.py
altoiddealer 7f26881361 Buttons without inheritance
Most credit goes to [LEv145](https://github.com/LEv145) for solving the inheritance issue in sd-webui-forge (due to their implementation of Gradio 4)
2024-08-21 12:11:25 -04:00

471 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import contextlib
import json
import math
import typing as t
from fractions import Fraction
from math import sqrt
from pathlib import Path

import gradio as gr

import modules.scripts as scripts
from modules.ui_components import ToolButton
# Base directory returned by scripts.basedir() (presumably this extension's root).
# The preset .txt files and javascript/button_titles.js are resolved relative to it.
BASE_PATH = scripts.basedir()
class ButtonState:
    """Toggle flags for one WebUI tab, consulted by the button click handlers.

    Attributes:
        is_locked: when True, the AR/resolution handlers pass the incoming
            average through instead of recomputing it from width/height.
        switched: when True, offset-mode AR results have width and height
            swapped, and one-dimension mode uses calc_width instead of
            calc_height.
        alt_mode: False selects "offset" mode (fit both sides to the average);
            True selects "one dimension" mode.
    """

    def __init__(self):
        self.is_locked, self.switched, self.alt_mode = False, False, False

    def _flip(self, flag_name):
        # Invert the named boolean attribute in place.
        setattr(self, flag_name, not getattr(self, flag_name))

    def toggle_lock(self):
        self._flip("is_locked")

    def toggle_switch(self):
        self._flip("switched")

    def toggle_mode(self):
        self._flip("alt_mode")


# One independent state object per tab.
txt2img_state = ButtonState()
img2img_state = ButtonState()
# Helper functions for calculating new width/height values
def round_to_precision(val, prec):
    """Snap *val* to the nearest multiple of *prec*.

    Uses the built-in round(), so exact halfway cases go to the even multiple
    (e.g. 96 -> 128 and 160 -> 128 for prec=64).
    """
    steps = round(val / prec)
    return steps * prec
def res_to_model_fit(avg, w, h, prec):
    """Uniformly rescale (w, h) so that w*h is about avg**2, then snap to *prec*.

    The w:h ratio is preserved by the common scale factor; each side is then
    rounded with round_to_precision and converted to int.

    Returns:
        (width, height) as ints.

    Raises:
        ZeroDivisionError: if w * h == 0 or prec == 0.
    """
    area_ratio = (avg * avg) / (w * h)
    factor = sqrt(area_ratio)
    fitted_w = int(round_to_precision(w * factor, prec))
    fitted_h = int(round_to_precision(h * factor, prec))
    return fitted_w, fitted_h
def _ceil_to_precision(val, prec):
    """Round *val* up to the next multiple of *prec*; exact multiples are kept.

    Replaces ``round_to_precision(val + prec / 2, prec)``. That expression acts
    as a ceiling for non-multiples, but because round() rounds half to even it
    bumped exact multiples by parity (prec=64: 512 stayed 512, 576 became 640).
    """
    return int(math.ceil(val / prec) * prec)


def calc_width(n, d, w, h, prec):
    """Fit (w, h) to the ratio n:d by recomputing a single side.

    - Landscape (n/d > 1): width is kept, height = width / ratio.
    - Portrait (n/d < 1): height is kept, width = height * ratio.
    - Square: both sides become max(w, h).
    Both results are then rounded up to a multiple of *prec*.

    Args:
        n, d: ratio numerator/denominator (ints, Fractions or floats).
        w, h: current width and height.
        prec: rounding step in pixels.

    Returns:
        (width, height) as ints.
    """
    # Exact ratio: rounding it to 2 decimals (as before) skewed results,
    # e.g. 21:9 at width 1344 produced height 640 instead of the exact 576.
    ar = Fraction(n) / Fraction(d)
    if ar > 1:
        h = w / ar
    elif ar < 1:
        w = h * ar
    else:
        w = h = max(w, h)
    return _ceil_to_precision(w, prec), _ceil_to_precision(h, prec)


def calc_height(n, d, w, h, prec):
    """Fit (w, h) to the ratio n:d by recomputing a single side.

    - Landscape (n/d > 1): height is kept, width = height * ratio.
    - Portrait (n/d < 1): width is kept, height = width / ratio.
    - Square: both sides become min(w, h).
    Both results are then rounded up to a multiple of *prec*.

    Args:
        n, d: ratio numerator/denominator (ints, Fractions or floats).
        w, h: current width and height.
        prec: rounding step in pixels.

    Returns:
        (width, height) as ints.
    """
    ar = Fraction(n) / Fraction(d)  # exact ratio (see calc_width)
    if ar > 1:
        w = h * ar
    elif ar < 1:
        h = w / ar
    else:
        w = h = min(w, h)
    return _ceil_to_precision(w, prec), _ceil_to_precision(h, prec)
def dims_from_ar(avg, n, d, prec):
    """Derive (width, height) for ratio n:d around the average side length *avg*.

    First splits ``2 * avg`` between width and height in the proportion n:d,
    then hands the pair to res_to_model_fit, which rescales it to an area of
    about avg**2 and snaps each side to *prec*.
    """
    side_sum = avg * 2
    parts = n + d
    raw_w = round((n / parts) * side_sum)
    raw_h = round((d / parts) * side_sum)
    fitted_w, fitted_h = res_to_model_fit(avg, raw_w, raw_h, prec)
    return fitted_w, fitted_h
def avg_from_dims(w, h):
    """Return the mean of *w* and *h*, rounded up when ``w + h`` is odd."""
    half, remainder = divmod(w + h, 2)
    return half + 1 if remainder != 0 else half
## Aspect Ratio buttons
def create_ar_button_function(ar:str, is_img2img:bool):
    """Build the click handler for the aspect-ratio button labelled *ar* ("W:H").

    The returned ``wrapper(avg, prec, w=512, h=512)`` yields ``(avg, w, h)``:
    - unless the tab is locked, *avg* is recomputed from the current w/h;
    - offset mode (alt_mode False): w/h come from dims_from_ar and are swapped
      when the switch is active;
    - one-dimension mode: calc_width (switched) or calc_height adjusts one side.
    The label is parsed on every click, not when the handler is created.
    """
    def wrapper(avg, prec, w=512, h=512):
        state = img2img_state if is_img2img else txt2img_state
        n, d = map(Fraction, ar.split(':'))
        if not state.is_locked:
            avg = avg_from_dims(w, h)
        if state.alt_mode:
            adjust = calc_width if state.switched else calc_height
            w, h = adjust(n, d, w, h, prec)
        else:
            w, h = dims_from_ar(avg, n, d, prec)
            if state.switched:
                w, h = h, w
        return avg, w, h
    return wrapper
def create_ar_buttons(
    lst: t.Iterable[str],
    is_img2img: bool,
) -> t.Tuple[t.List[ToolButton], t.Dict[ToolButton, t.Callable]]:
    """Create one unrendered ToolButton per ratio label, plus its click handler.

    Buttons are built with render=False so the caller can place them later.

    Returns:
        (buttons, functions): the buttons in input order, and a dict mapping
        each button to the handler from create_ar_button_function.
    """
    # Consume *lst* exactly once; it may be a one-shot iterator.
    labelled = [(ToolButton(ar, render=False), ar) for ar in lst]
    buttons = [button for button, _ in labelled]
    functions = {
        button: create_ar_button_function(ar, is_img2img)
        for button, ar in labelled
    }
    return buttons, functions
## Static Resolution buttons
def create_res_button_function(w:int, h:int, is_img2img:bool):
    """Build the click handler for a static-resolution button of size w x h.

    The returned ``wrapper(avg)`` yields ``(avg, w, h)``. The average is
    recomputed from the preset size unless the tab is locked, in which case
    the incoming value is passed through unchanged.
    """
    def wrapper(avg):
        tab_state = img2img_state if is_img2img else txt2img_state
        new_avg = avg if tab_state.is_locked else avg_from_dims(w, h)
        return new_avg, w, h
    return wrapper
def create_res_buttons(
    lst: t.Iterable[t.Tuple[t.List[int], str]],
    is_img2img: bool,
) -> t.Tuple[t.List[ToolButton], t.Dict[ToolButton, t.Callable]]:
    """Create (and render in the current layout context) one button per preset.

    Args:
        lst: iterable of ``(resolution, label)``, where resolution is a
            ``[width, height]`` pair.
        is_img2img: selects which tab's state the handlers consult.

    Returns:
        (buttons, functions): the buttons in input order, and a dict mapping
        each button to its create_res_button_function handler.
    """
    buttons, functions = [], {}
    for resolution, label in lst:
        new_button = ToolButton(label)
        width, height = resolution
        functions[new_button] = create_res_button_function(width, height, is_img2img)
        buttons.append(new_button)
    return buttons, functions
# Get values for Aspect Ratios from file
def parse_aspect_ratios_file(filename):
    """Read aspect-ratio presets from *filename* (resolved under BASE_PATH).

    Each data line has the form ``W:H`` or ``W:H # comment``. Blank lines and
    lines whose first non-whitespace character is ``#`` are skipped.

    Returns:
        (values, comments, flipvals):
        - values: the "W:H" strings;
        - comments: text after the first "#" ("" when absent; leading
          whitespace is kept);
        - flipvals: the "H:W" strings.
        All three are empty when the file does not exist.

    Raises:
        ValueError: if a data line does not contain exactly one ":".
    """
    values, comments, flipvals = [], [], []
    file = Path(BASE_PATH, filename)
    if not file.exists():
        return values, comments, flipvals
    # utf-8-sig drops a BOM that Windows editors may prepend; otherwise it
    # would stick to the first ratio (or hide a leading "#" comment).
    with open(file, "r", encoding="utf-8-sig") as f:
        lines = f.readlines()
    for line in lines:
        stripped = line.strip()
        # Blank lines used to crash the ":" unpack below; indented comments
        # used to be parsed as data.
        if not stripped or stripped.startswith("#"):
            continue
        # partition() keeps any additional "#" inside the comment instead of
        # failing with a too-many-values unpack.
        value, _, comment = stripped.partition("#")
        value = value.strip()
        comp1, comp2 = value.split(":")
        values.append(value)
        comments.append(comment)
        flipvals.append(f"{comp2}:{comp1}")
    return values, comments, flipvals
# Get values for Static Resolutions from file
def parse_resolutions_file(filename):
    """Read static-resolution presets from *filename* (resolved under BASE_PATH).

    Each data line has the form ``label, width, height`` optionally followed by
    ``# comment``. Blank lines and lines whose first non-whitespace character
    is ``#`` are skipped.

    Returns:
        (labels, values, comments):
        - labels: the label strings;
        - values: ``(width, height)`` tuples of *unconverted* strings (they may
          carry surrounding whitespace; the caller converts them with int());
        - comments: text after the first "#" ("" when absent).
        All three are empty when the file does not exist.

    Raises:
        ValueError: if a data line has fewer than three comma-separated fields.
    """
    labels, values, comments = [], [], []
    file = Path(BASE_PATH, filename)
    if not file.exists():
        return labels, values, comments
    # utf-8-sig drops a BOM that Windows editors may prepend.
    with open(file, "r", encoding="utf-8-sig") as f:
        lines = f.readlines()
    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        # Split on the first two commas only, so a comment may itself contain
        # commas (previously that raised a too-many-values unpack error).
        label, width, rest = stripped.split(",", 2)
        height, _, comment = rest.partition("#")
        labels.append(label)
        values.append((width, height))
        comments.append(comment)
    return labels, values, comments
def write_aspect_ratios_file(filename):
    """Create *filename* containing the default aspect-ratio presets.

    Writes one ``W:H # description`` entry per line; the last line has no
    trailing newline. An existing file is overwritten.
    """
    default_presets = (
        "1:1 # Square\n",
        "4:3 # Television Photography\n",
        "3:2 # Photography\n",
        "8:5 # Widescreen Displays\n",
        "16:9 # Widescreen Television\n",
        "21:9 # Ultrawide Cinematography",
    )
    with open(filename, "w", encoding="utf-8") as handle:
        handle.write("".join(default_presets))
def write_resolutions_file(filename):
    """Create *filename* containing the default square resolution presets.

    Writes one ``label, width, height # comment`` entry per line; the last line
    has no trailing newline. An existing file is overwritten.
    """
    default_presets = (
        "512, 512, 512 # 512x512\n",
        "768, 768, 768 # 768x768\n",
        "1024, 1024, 1024 # 1024x1024\n",
        "1280, 1280, 1280 # 1280x1280\n",
        "1536, 1536, 1536 # 1536x1536\n",
        "2048, 2048, 2048 # 2048x2048",
    )
    with open(filename, "w", encoding="utf-8") as handle:
        handle.write("".join(default_presets))
def write_js_titles_file(button_titles):
    """Regenerate javascript/button_titles.js, mapping button labels to titles.

    Args:
        button_titles: a pair ``[labels, comments]`` of equal-length sequences
            of strings.

    Each entry is written as ``" <label>" : "<comment>",``. The leading space
    in the key is kept exactly as before; presumably the JS side matches on it,
    so confirm before changing. Keys and values are JSON-encoded: a quote,
    backslash or newline in a user-edited comment previously produced invalid
    JavaScript.
    """
    filename = Path(BASE_PATH, "javascript", "button_titles.js")
    labels, comments = button_titles[0], button_titles[1]
    content = ["// Do not put custom titles here. This file is overwritten each time the WebUI is started.\n"]
    content.append("arsp__ar_button_titles = {\n")
    for label, comment in zip(labels, comments):
        # ensure_ascii=False keeps non-ASCII text readable; json.dumps still
        # escapes quotes, backslashes and control characters.
        key = json.dumps(f" {label}", ensure_ascii=False)
        title = json.dumps(comment, ensure_ascii=False)
        content.append(f" {key} : {title},\n")
    content.append("}")
    with open(filename, "w", encoding="utf-8") as f:
        f.writelines(content)
class AspectRatioScript(scripts.Script):
    """Always-visible script adding aspect-ratio and static-resolution buttons.

    Button clicks write into the tab's width/height components, which are
    captured by after_component(). Presets are read from aspect_ratios.txt and
    resolutions.txt under BASE_PATH; default files are written when missing.
    """

    def read_aspect_ratios(self):
        """Load aspect-ratio presets, creating the default file if it is missing.

        Sets self.aspect_ratios, self.aspect_ratio_comments, self.flipped_vals
        and self.ar_buttons_labels (the unflipped ratios, used as button labels).
        """
        ar_file = Path(BASE_PATH, "aspect_ratios.txt")
        if not ar_file.exists():
            write_aspect_ratios_file(ar_file)
        (self.aspect_ratios, self.aspect_ratio_comments, self.flipped_vals) = parse_aspect_ratios_file("aspect_ratios.txt")
        self.ar_buttons_labels = self.aspect_ratios

    def read_resolutions(self):
        """Load static-resolution presets, creating the default file if it is missing.

        Sets self.res_labels, self.res_comments and self.res, a list of
        [width, height] int pairs.
        """
        res_file = Path(BASE_PATH, "resolutions.txt")
        if not res_file.exists():
            write_resolutions_file(res_file)
        self.res_labels, res, self.res_comments = parse_resolutions_file("resolutions.txt")
        # The parser returns width/height as strings; convert them for the handlers.
        self.res = [list(map(int, r)) for r in res]

    def title(self):
        """Return the script's title."""
        return "Aspect Ratio picker"

    def show(self, is_img2img):
        """Return scripts.AlwaysVisible for both tabs (is_img2img is not consulted)."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the button accordion for one tab and wire up its events.

        Args:
            is_img2img: True when building the img2img tab, False for txt2img.

        Side effect: (re)writes javascript/button_titles.js via
        write_js_titles_file.

        NOTE(review): relies on after_component() having already stored this
        tab's width/height components (self.t2i_w/t2i_h or self.i2i_w/i2i_h);
        if it has not, the attribute lookups below raise AttributeError.
        """
        # Icons used as ToolButton labels; toggle_switch and toggle_mode compare against them.
        self.LOCK_OPEN_ICON = "\U0001F513" # 🔓
        self.LOCK_CLOSED_ICON = "\U0001F512" # 🔒
        self.LAND_AR_ICON = "\U000025AD" # ▭
        self.PORT_AR_ICON = "\U000025AF" # ▯
        self.INFO_ICON = "\U00002139" # ℹ
        self.INFO_CLOSE_ICON = "\U00002BC5" # ⯅
        self.OFFSET_ICON = "\U00002B83" # ⮃
        self.ONE_DIM_ICON = "\U00002B85" # ⮅
        # Determine the width and height based on the mode (img2img or txt2img)
        if is_img2img:
            w = self.i2i_w
            h = self.i2i_h
        else:
            w = self.t2i_w
            h = self.t2i_h
        # Average number box initialize without rendering
        # (rendered later in the info panel, but needed earlier as an event input/output)
        arc_average = gr.Number(label="Current W/H Avg.", value=0, interactive=False, render=False)
        # Precision input box initialize without rendering
        arc_prec = gr.Number(label="Precision (px)", value=64, minimum=4, maximum=128, step=4, precision=0, render=False)
        with gr.Accordion(label="Aspect Ratio and Resolution Buttons", open=True):
            with gr.Column(elem_id=f'arsp__{"img" if is_img2img else "txt"}2img_container_aspect_ratio'):
                # Get aspect ratios from file
                self.read_aspect_ratios()
                # Top row
                with gr.Row(elem_id=f'arsp__{"img" if is_img2img else "txt"}2img_row_aspect_ratio'):
                    # Lock button
                    arc_lock = ToolButton(value=self.LOCK_OPEN_ICON, visible=True, variant="secondary", elem_id="arsp__arc_lock_button")
                    # Lock button click event handling.
                    # Picks the icon for the state *after* toggling (chosen before the
                    # toggle), flips the tab's lock, and seeds the average from the
                    # current width/height if it is still 0/empty. The incoming `icon`
                    # value is not used.
                    def toggle_lock(icon, avg, w=512, h=512):
                        icon = self.LOCK_OPEN_ICON if (img2img_state.is_locked if is_img2img else txt2img_state.is_locked) else self.LOCK_CLOSED_ICON
                        if is_img2img:
                            img2img_state.toggle_lock()
                        else:
                            txt2img_state.toggle_lock()
                        if not avg:
                            avg = avg_from_dims(w, h)
                        return icon, avg
                    # Same components as w/h above; used as the lock handler's inputs
                    if is_img2img:
                        lock_w = self.i2i_w
                        lock_h = self.i2i_h
                    else:
                        lock_w = self.t2i_w
                        lock_h = self.t2i_h
                    # Lock button event listener
                    arc_lock.click(toggle_lock,
                        inputs = [arc_lock, arc_average, lock_w, lock_h],
                        outputs = [arc_lock, arc_average],
                        show_progress = 'hidden')
                    # Initialize Aspect Ratio buttons (render=False)
                    ar_buttons, ar_functions = create_ar_buttons(self.ar_buttons_labels, is_img2img)
                    # Switch button
                    arc_switch = ToolButton(value=self.LAND_AR_ICON, visible=True, variant="secondary", elem_id="arsp__arc_switch_button")
                    # Switch button click event handling.
                    # Receives every AR button's current value followed by the switch icon.
                    # Relabels all AR buttons between the file ratios and their flipped
                    # "H:W" forms, swaps the switch icon, and toggles the tab's `switched` flag.
                    def toggle_switch(*items, **kwargs):
                        ar_icons = items[:-1]
                        sw_icon = items[-1]
                        if ar_icons == tuple(self.aspect_ratios):
                            ar_icons = tuple(self.flipped_vals)
                        else:
                            ar_icons = tuple(self.aspect_ratios)
                        sw_icon = self.PORT_AR_ICON if sw_icon == self.LAND_AR_ICON else self.LAND_AR_ICON
                        if is_img2img:
                            img2img_state.toggle_switch()
                        else:
                            txt2img_state.toggle_switch()
                        return (*ar_icons, sw_icon)
                    # Switch button event listener
                    arc_switch.click(toggle_switch,
                        inputs = ar_buttons+[arc_switch],
                        outputs = ar_buttons+[arc_switch])
                    # AR buttons render (after the lock and switch buttons in this row)
                    for button in ar_buttons:
                        button.render()
                    # AR buttons click event handling
                    # NOTE(review): AttributeError raised while binding is silently ignored;
                    # the reason is not visible in this file.
                    with contextlib.suppress(AttributeError):
                        for button in ar_buttons:
                            # AR buttons event listener
                            button.click(ar_functions[button],
                                inputs = [arc_average, arc_prec, w, h],
                                outputs = [arc_average, w, h],
                                show_progress = 'hidden')
                # Get static resolutions from file
                self.read_resolutions()
                # Bottom row
                with gr.Row(elem_id=f'arsp__{"img" if is_img2img else "txt"}2img_row_resolutions'):
                    # Info button to toggle info window
                    arc_show_info = ToolButton(value=self.INFO_ICON, visible=True, variant="secondary", elem_id="arsp__arc_show_info_button")
                    arc_hide_info = ToolButton(value=self.INFO_CLOSE_ICON, visible=False, variant="secondary", elem_id="arsp__arc_hide_info_button")
                    ### Click event handling for info window ###
                    ##### is defined after everything else #####
                    # Mode button
                    arc_mode = ToolButton(value=self.OFFSET_ICON, visible=True, variant="secondary", elem_id="arsp__arc_mode_button")
                    # Mode button click event handling.
                    # Swaps the offset/one-dimension icon and toggles the tab's alt_mode flag.
                    def toggle_mode(icon):
                        icon = self.ONE_DIM_ICON if icon == self.OFFSET_ICON else self.OFFSET_ICON
                        if is_img2img:
                            img2img_state.toggle_mode()
                        else:
                            txt2img_state.toggle_mode()
                        return icon
                    # Mode button event listener
                    arc_mode.click(toggle_mode,
                        inputs = [arc_mode],
                        outputs = [arc_mode])
                    # Static res buttons (created without render=False, so they are placed in this row)
                    buttons, set_res_functions = create_res_buttons(zip(self.res, self.res_labels), is_img2img)
                    # Set up click event listeners for the buttons
                    with contextlib.suppress(AttributeError):
                        for button in buttons:
                            button.click(set_res_functions[button],
                                inputs = [arc_average],
                                outputs = [arc_average, w, h],
                                show_progress = 'hidden')
                # Write button_titles.js with labels and comments read from aspect ratios and resolutions files
                # (index 0: all labels, index 1: the matching comments, in the same order)
                button_titles = [self.aspect_ratios + self.res_labels]
                button_titles.append(self.aspect_ratio_comments + self.res_comments)
                write_js_titles_file(button_titles)
                # Information panel (created hidden; shown by the info button)
                with gr.Column(visible=False, variant="panel", elem_id="arsp__arc_panel") as arc_panel:
                    with gr.Row():
                        with gr.Column(scale=2, min_width=100):
                            # Average number box render
                            arc_average.render()
                            # Precision input box render
                            arc_prec.render()
                        # Information blurb (this empty Column presumably acts as a spacer)
                        gr.Column(scale=1, min_width=10)
                        with gr.Column(scale=12):
                            arc_title_heading = gr.Markdown(value=
'''
### AR and Static Res buttons can be customized in the 'aspect_ratios.txt' and 'resolutions.txt' files
**Aspect Ratio buttons (Top Row)**:
(1) Averages the current width/height in the UI; (2) Offsets to the exact aspect ratio; (3) Rounds to precision.
**Static Resolution buttons (Bottom Row)**:
Recommended to use 1:1 values for these, to serve as a start point before clicking AR buttons.
**64px Precision is recommended, the same rounding applied for image "bucketing" when model training.**
'''
                            )
                # Info panel event listeners.
                # Show: reveal the panel, hide the show button, reveal the hide button.
                arc_show_info.click(
                    lambda: [
                        gr.update(visible=True),
                        gr.update(visible=False),
                        gr.update(visible=True),
                    ],
                    None,
                    [
                        arc_panel,
                        arc_show_info,
                        arc_hide_info,
                    ],
                )
                # Hide: the reverse of the above.
                arc_hide_info.click(
                    lambda: [
                        gr.update(visible=False),
                        gr.update(visible=True),
                        gr.update(visible=False),
                    ],
                    None,
                    [arc_panel, arc_show_info, arc_hide_info],
                )

    ## Function to update the values in appropriate Width/Height fields
    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/7456#issuecomment-1414465888
    def after_component(self, component, **kwargs):
        """Store the txt2img/img2img width and height components, matched by elem_id.

        ui() reads these attributes (t2i_w, t2i_h, i2i_w, i2i_h) to use the
        components as event inputs and outputs. Components with any other
        elem_id are ignored.
        """
        if kwargs.get("elem_id") == "txt2img_width":
            self.t2i_w = component
        if kwargs.get("elem_id") == "txt2img_height":
            self.t2i_h = component
        if kwargs.get("elem_id") == "img2img_width":
            self.i2i_w = component
        if kwargs.get("elem_id") == "img2img_height":
            self.i2i_h = component