Merge pull request #4 from altoiddealer/dev

Buttons without inheritance
This commit is contained in:
altoiddealer
2024-08-21 12:27:17 -04:00
committed by GitHub
2 changed files with 116 additions and 142 deletions

View File

@@ -1,3 +1,4 @@
import typing as t
import contextlib
from pathlib import Path
@@ -7,10 +8,25 @@ import modules.scripts as scripts
from modules.ui_components import ToolButton
from fractions import Fraction
from math import gcd, sqrt
from math import sqrt
BASE_PATH = scripts.basedir()
class ButtonState():
def __init__(self):
self.is_locked = False
self.switched = False
self.alt_mode = False
def toggle_lock(self):
self.is_locked = not self.is_locked
def toggle_switch(self):
self.switched = not self.switched
def toggle_mode(self):
self.alt_mode = not self.alt_mode
txt2img_state = ButtonState()
img2img_state = ButtonState()
# Helper functions for calculating new width/height values
def round_to_precision(val, prec):
return round(val / prec) * prec
@@ -66,125 +82,69 @@ def avg_from_dims(w, h):
return avg
## Aspect Ratio buttons
# two subclasses are necessary to properly split behavior among txt2img and img2img tabs
# txt2img AR Buttons
class txt2imgARButtons(ToolButton):
is_locked = False
switched = False
alt_mode = False
def __init__(self, value='1:1', **kwargs):
super().__init__(**kwargs)
self.value = value
def apply(self, avg, prec, w=512, h=512):
ar = self.value
n, d = map(Fraction, ar.split(':')) # split numerator and denominator
if not txt2imgARButtons.is_locked:
avg = avg_from_dims(w, h) # Get average of current width/height values
if not txt2imgARButtons.alt_mode: # True = offset, False = One dimension
w, h = dims_from_ar(avg, n, d, prec) # Calculate new w + h from avg, AR, and precision
if txt2imgARButtons.switched: # Switch results if switch mode active
def create_ar_button_function(ar:str, is_img2img:bool):
def wrapper(avg, prec, w=512, h=512):
# Determine the state based on whether it's img2img or txt2img
state = img2img_state if is_img2img else txt2img_state
n, d = map(Fraction, ar.split(':')) # Split numerator and denominator
if not state.is_locked:
avg = avg_from_dims(w, h) # Get average of current width/height values
if not state.alt_mode: # True = offset, False = One dimension
w, h = dims_from_ar(avg, n, d, prec) # Calculate new w + h from avg, AR, and precision
if state.switched: # Switch results if switch mode is active
w, h = h, w
else: # Calculate w or h from input, AR, and precision
if txt2imgARButtons.switched: # Switch results if switch mode active
w, h = calc_width(n, d, w, h, prec) # Modify width
else: # Calculate w or h from input, AR, and precision
if state.switched: # Switch results if switch mode is active
w, h = calc_width(n, d, w, h, prec) # Modify width
else:
w, h = calc_height(n, d, w, h, prec) # Modify height
w, h = calc_height(n, d, w, h, prec) # Modify height
return avg, w, h
# Toggle all buttons in subclass
@classmethod
def toggle_lock(cls):
cls.is_locked = not cls.is_locked
@classmethod
def toggle_switch(cls):
cls.switched = not cls.switched
@classmethod
def toggle_mode(cls):
cls.alt_mode = not cls.alt_mode
# img2img AR Buttons
class img2imgARButtons(ToolButton):
is_locked = False
switched = False
alt_mode = False
return wrapper
def __init__(self, value='1:1', **kwargs):
super().__init__(**kwargs)
self.value = value
def create_ar_buttons(
lst: t.Iterable[str],
is_img2img: bool,
) -> t.Tuple[t.List[ToolButton], t.Dict[ToolButton, t.Callable]]:
buttons = []
functions = {}
def apply(self, avg, prec, w=512, h=512):
ar = self.value
n, d = map(Fraction, ar.split(':'))
if not img2imgARButtons.is_locked:
avg = avg_from_dims(w, h)
if not img2imgARButtons.alt_mode:
w, h = dims_from_ar(avg, n, d, prec)
if img2imgARButtons.switched:
w, h = h, w
else:
if img2imgARButtons.switched:
w, h = calc_width(n, d, w, h, prec)
else:
w, h = calc_height(n, d, w, h, prec)
return avg, w, h
# Toggle all buttons in subclass
@classmethod
def toggle_lock(cls):
cls.is_locked = not cls.is_locked
@classmethod
def toggle_switch(cls):
cls.switched = not cls.switched
@classmethod
def toggle_mode(cls):
cls.alt_mode = not cls.alt_mode
for ar in lst:
button = ToolButton(ar, render=False)
function = create_ar_button_function(ar, is_img2img)
buttons.append(button)
functions[button] = function
return buttons, functions
## Static Resolution buttons
# two subclasses are necessary to properly split behavior among txt2img and img2img tabs
def create_res_button_function(w:int, h:int, is_img2img:bool):
def wrapper(avg):
state = img2img_state if is_img2img else txt2img_state
if not state.is_locked:
avg = avg_from_dims(w, h)
return avg, w, h
return wrapper
# txt2img Static Resolution buttons
class txt2imgResButtons(ToolButton):
is_locked = False
def create_res_buttons(
lst: t.Iterable[t.Tuple[t.List[int], str]],
is_img2img: bool,
) -> t.Tuple[t.List[ToolButton], t.Dict[ToolButton, t.Callable]]:
buttons = []
functions = {}
def __init__(self, res=(512, 512), **kwargs):
super().__init__(**kwargs)
self.w, self.h = res
for resolution, label in lst:
button = ToolButton(label)
w, h = resolution
function = create_res_button_function(w, h, is_img2img)
buttons.append(button)
functions[button] = function
def reset(self, avg):
# Get average of current width/height values
if not txt2imgResButtons.is_locked:
avg = avg_from_dims(self.w, self.h)
return avg, self.w, self.h
# Toggle all buttons in subclass
@classmethod
def toggle_lock(cls):
cls.is_locked = not cls.is_locked
# img2img Static Resolution buttons
class img2imgResButtons(ToolButton):
is_locked = False
def __init__(self, res=(512, 512), **kwargs):
super().__init__(**kwargs)
self.w, self.h = res
def reset(self, avg):
# Get average of current width/height values
if not img2imgResButtons.is_locked:
avg = avg_from_dims(self.w, self.h)
return avg, self.w, self.h
# Toggle all buttons in subclass
@classmethod
def toggle_lock(cls):
cls.is_locked = not cls.is_locked
# Hack for Gradio 4.0; see `get_component_class_id` in `venv/lib/site-packages/gradio/components/base.py`
txt2imgARButtons.__module__ = "modules.ui_components"
img2imgARButtons.__module__ = "modules.ui_components"
txt2imgResButtons.__module__ = "modules.ui_components"
img2imgResButtons.__module__ = "modules.ui_components"
return buttons, functions
# Get values for Aspect Ratios from file
def parse_aspect_ratios_file(filename):
@@ -293,7 +253,7 @@ class AspectRatioScript(scripts.Script):
if not ar_file.exists():
write_aspect_ratios_file(ar_file)
(self.aspect_ratios, self.aspect_ratio_comments, self.flipped_vals) = parse_aspect_ratios_file("aspect_ratios.txt")
self.ar_btns_labels = self.aspect_ratios
self.ar_buttons_labels = self.aspect_ratios
def read_resolutions(self):
res_file = Path(BASE_PATH, "resolutions.txt")
@@ -318,9 +278,17 @@ class AspectRatioScript(scripts.Script):
self.INFO_CLOSE_ICON = "\U00002BC5" # ⯅
self.OFFSET_ICON = "\U00002B83" # ⮃
self.ONE_DIM_ICON = "\U00002B85" # ⮅
# Determine the width and height based on the mode (img2img or txt2img)
if is_img2img:
w = self.i2i_w
h = self.i2i_h
else:
w = self.t2i_w
h = self.t2i_h
# Average number box initialize without rendering
arc_avg = gr.Number(label="Current W/H Avg.", value=0, interactive=False, render=False)
arc_average = gr.Number(label="Current W/H Avg.", value=0, interactive=False, render=False)
# Precision input box initialize without rendering
arc_prec = gr.Number(label="Precision (px)", value=64, minimum=4, maximum=128, step=4, precision=0, render=False)
@@ -339,13 +307,11 @@ class AspectRatioScript(scripts.Script):
arc_lock = ToolButton(value=self.LOCK_OPEN_ICON, visible=True, variant="secondary", elem_id="arsp__arc_lock_button")
# Lock button click event handling
def toggle_lock(icon, avg, w=512, h=512):
icon = self.LOCK_OPEN_ICON if (img2imgARButtons.is_locked if is_img2img else txt2imgARButtons.is_locked) else self.LOCK_CLOSED_ICON
icon = self.LOCK_OPEN_ICON if (img2img_state.is_locked if is_img2img else txt2img_state.is_locked) else self.LOCK_CLOSED_ICON
if is_img2img:
img2imgARButtons.toggle_lock()
img2imgResButtons.toggle_lock()
img2img_state.toggle_lock()
else:
txt2imgARButtons.toggle_lock()
txt2imgResButtons.toggle_lock()
txt2img_state.toggle_lock()
if not avg:
avg = avg_from_dims(w, h)
return icon, avg
@@ -356,13 +322,16 @@ class AspectRatioScript(scripts.Script):
lock_w = self.t2i_w
lock_h = self.t2i_h
# Lock button event listener
arc_lock.click(toggle_lock, inputs=[arc_lock, arc_avg, lock_w, lock_h], outputs=[arc_lock, arc_avg])
arc_lock.click(toggle_lock,
inputs = [arc_lock, arc_average, lock_w, lock_h],
outputs = [arc_lock, arc_average],
show_progress = 'hidden')
# Initialize Aspect Ratio buttons (render=False)
ar_btns = [img2imgARButtons(value=ar, render=False) if is_img2img else txt2imgARButtons(value=ar, render=False) for ar in self.ar_btns_labels]
ar_buttons, ar_functions = create_ar_buttons(self.ar_buttons_labels, is_img2img)
# Switch button
arc_sw = ToolButton(value=self.LAND_AR_ICON, visible=True, variant="secondary", elem_id="arsp__arc_sw_button")
arc_switch = ToolButton(value=self.LAND_AR_ICON, visible=True, variant="secondary", elem_id="arsp__arc_switch_button")
# Switch button click event handling
def toggle_switch(*items, **kwargs):
ar_icons = items[:-1]
@@ -373,27 +342,26 @@ class AspectRatioScript(scripts.Script):
ar_icons = tuple(self.aspect_ratios)
sw_icon = self.PORT_AR_ICON if sw_icon == self.LAND_AR_ICON else self.LAND_AR_ICON
if is_img2img:
img2imgARButtons.toggle_switch()
img2img_state.toggle_switch()
else:
txt2imgARButtons.toggle_switch()
txt2img_state.toggle_switch()
return (*ar_icons, sw_icon)
# Switch button event listener
arc_sw.click(toggle_switch, inputs=ar_btns+[arc_sw], outputs=ar_btns+[arc_sw])
arc_switch.click(toggle_switch,
inputs = ar_buttons+[arc_switch],
outputs = ar_buttons+[arc_switch])
# AR buttons render
for b in ar_btns:
b.render()
for button in ar_buttons:
button.render()
# AR buttons click event handling
with contextlib.suppress(AttributeError):
for b in ar_btns:
if is_img2img:
w = self.i2i_w
h = self.i2i_h
else:
w = self.t2i_w
h = self.t2i_h
for button in ar_buttons:
# AR buttons event listener
b.click(b.apply, inputs=[arc_avg, arc_prec, w, h], outputs=[arc_avg, w, h])
button.click(ar_functions[button],
inputs = [arc_average, arc_prec, w, h],
outputs = [arc_average, w, h],
show_progress = 'hidden')
# Get static resolutions from file
self.read_resolutions()
@@ -413,19 +381,25 @@ class AspectRatioScript(scripts.Script):
def toggle_mode(icon):
icon = self.ONE_DIM_ICON if icon == self.OFFSET_ICON else self.OFFSET_ICON
if is_img2img:
img2imgARButtons.toggle_mode()
img2img_state.toggle_mode()
else:
txt2imgARButtons.toggle_mode()
txt2img_state.toggle_mode()
return icon
# Mode button event listener
arc_mode.click(toggle_mode, inputs=[arc_mode], outputs=[arc_mode])
arc_mode.click(toggle_mode,
inputs = [arc_mode],
outputs = [arc_mode])
# Static res buttons
btns = [img2imgResButtons(res=res, value=label) if is_img2img else txt2imgResButtons(res=res, value=label) for res, label in zip(self.res, self.res_labels)]
# Static res buttons event listener
buttons, set_res_functions = create_res_buttons(zip(self.res, self.res_labels), is_img2img)
# Set up click event listeners for the buttons
with contextlib.suppress(AttributeError):
for b in btns:
b.click(b.reset, inputs=[arc_avg], outputs=[arc_avg, w, h])
for button in buttons:
button.click(set_res_functions[button],
inputs = [arc_average],
outputs = [arc_average, w, h],
show_progress = 'hidden')
# Write button_titles.js with labels and comments read from aspect ratios and resolutions files
button_titles = [self.aspect_ratios + self.res_labels]
@@ -438,7 +412,7 @@ class AspectRatioScript(scripts.Script):
with gr.Column(scale=2, min_width=100):
# Average number box render
arc_avg.render()
arc_average.render()
# Precision input box render
arc_prec.render()
@@ -493,4 +467,4 @@ class AspectRatioScript(scripts.Script):
if kwargs.get("elem_id") == "img2img_width":
self.i2i_w = component
if kwargs.get("elem_id") == "img2img_height":
self.i2i_h = component
self.i2i_h = component

View File

@@ -22,7 +22,7 @@
button#arsp__arc_lock_button,
button#arsp__arc_mode_button,
button#arsp__arc_sw_button,
button#arsp__arc_switch_button,
button#arsp__arc_show_info_button,
button#arsp__arc_hide_info_button {
max-width: 40px !important;