diff --git a/scripts/sd-webui-ar-plusplus.py b/scripts/sd-webui-ar-plusplus.py index 42d8e62..99a24df 100644 --- a/scripts/sd-webui-ar-plusplus.py +++ b/scripts/sd-webui-ar-plusplus.py @@ -1,3 +1,4 @@ +import typing as t import contextlib from pathlib import Path @@ -7,10 +8,25 @@ import modules.scripts as scripts from modules.ui_components import ToolButton from fractions import Fraction -from math import gcd, sqrt +from math import sqrt BASE_PATH = scripts.basedir() +class ButtonState(): + def __init__(self): + self.is_locked = False + self.switched = False + self.alt_mode = False + def toggle_lock(self): + self.is_locked = not self.is_locked + def toggle_switch(self): + self.switched = not self.switched + def toggle_mode(self): + self.alt_mode = not self.alt_mode + +txt2img_state = ButtonState() +img2img_state = ButtonState() + # Helper functions for calculating new width/height values def round_to_precision(val, prec): return round(val / prec) * prec @@ -66,125 +82,69 @@ def avg_from_dims(w, h): return avg ## Aspect Ratio buttons -# two subclasses are necessary to properly split behavior among txt2img and img2img tabs - -# txt2img AR Buttons -class txt2imgARButtons(ToolButton): - is_locked = False - switched = False - alt_mode = False - - def __init__(self, value='1:1', **kwargs): - super().__init__(**kwargs) - self.value = value - - def apply(self, avg, prec, w=512, h=512): - ar = self.value - n, d = map(Fraction, ar.split(':')) # split numerator and denominator - if not txt2imgARButtons.is_locked: - avg = avg_from_dims(w, h) # Get average of current width/height values - if not txt2imgARButtons.alt_mode: # True = offset, False = One dimension - w, h = dims_from_ar(avg, n, d, prec) # Calculate new w + h from avg, AR, and precision - if txt2imgARButtons.switched: # Switch results if switch mode active +def create_ar_button_function(ar:str, is_img2img:bool): + def wrapper(avg, prec, w=512, h=512): + # Determine the state based on whether it's img2img or txt2img + state = img2img_state if is_img2img else txt2img_state + + n, d = map(Fraction, ar.split(':')) # Split numerator and denominator + + if not state.is_locked: + avg = avg_from_dims(w, h) # Get average of current width/height values + + if not state.alt_mode: # True = offset, False = One dimension + w, h = dims_from_ar(avg, n, d, prec) # Calculate new w + h from avg, AR, and precision + if state.switched: # Switch results if switch mode is active w, h = h, w - else: # Calculate w or h from input, AR, and precision - if txt2imgARButtons.switched: # Switch results if switch mode active - w, h = calc_width(n, d, w, h, prec) # Modify width + else: # Calculate w or h from input, AR, and precision + if state.switched: # Switch results if switch mode is active + w, h = calc_width(n, d, w, h, prec) # Modify width else: - w, h = calc_height(n, d, w, h, prec) # Modify height + w, h = calc_height(n, d, w, h, prec) # Modify height + return avg, w, h - # Toggle all buttons in subclass - @classmethod - def toggle_lock(cls): - cls.is_locked = not cls.is_locked - @classmethod - def toggle_switch(cls): - cls.switched = not cls.switched - @classmethod - def toggle_mode(cls): - cls.alt_mode = not cls.alt_mode -# img2img AR Buttons -class img2imgARButtons(ToolButton): - is_locked = False - switched = False - alt_mode = False + return wrapper - def __init__(self, value='1:1', **kwargs): - super().__init__(**kwargs) - self.value = value +def create_ar_buttons( + lst: t.Iterable[str], + is_img2img: bool, +) -> t.Tuple[t.List[ToolButton], t.Dict[ToolButton, t.Callable]]: + buttons = [] + functions = {} - def apply(self, avg, prec, w=512, h=512): - ar = self.value - n, d = map(Fraction, ar.split(':')) - if not img2imgARButtons.is_locked: - avg = avg_from_dims(w, h) - if not img2imgARButtons.alt_mode: - w, h = dims_from_ar(avg, n, d, prec) - if img2imgARButtons.switched: - w, h = h, w - else: - if img2imgARButtons.switched: - w, h = calc_width(n, d, w, h, prec) - else: - w, h = calc_height(n, d, w, h, prec) - return avg, w, h - # Toggle all buttons in subclass - @classmethod - def toggle_lock(cls): - cls.is_locked = not cls.is_locked - @classmethod - def toggle_switch(cls): - cls.switched = not cls.switched - @classmethod - def toggle_mode(cls): - cls.alt_mode = not cls.alt_mode + for ar in lst: + button = ToolButton(ar, render=False) + function = create_ar_button_function(ar, is_img2img) + buttons.append(button) + functions[button] = function + return buttons, functions ## Static Resolution buttons -# two subclasses are necessary to properly split behavior among txt2img and img2img tabs +def create_res_button_function(w:int, h:int, is_img2img:bool): + def wrapper(avg): + state = img2img_state if is_img2img else txt2img_state + if not state.is_locked: + avg = avg_from_dims(w, h) + return avg, w, h + return wrapper -# txt2img Static Resolution buttons -class txt2imgResButtons(ToolButton): - is_locked = False +def create_res_buttons( + lst: t.Iterable[t.Tuple[t.List[int], str]], + is_img2img: bool, +) -> t.Tuple[t.List[ToolButton], t.Dict[ToolButton, t.Callable]]: + buttons = [] + functions = {} - def __init__(self, res=(512, 512), **kwargs): - super().__init__(**kwargs) - self.w, self.h = res + for resolution, label in lst: + button = ToolButton(label) + w, h = resolution + function = create_res_button_function(w, h, is_img2img) + buttons.append(button) + functions[button] = function - def reset(self, avg): - # Get average of current width/height values - if not txt2imgResButtons.is_locked: - avg = avg_from_dims(self.w, self.h) - return avg, self.w, self.h - # Toggle all buttons in subclass - @classmethod - def toggle_lock(cls): - cls.is_locked = not cls.is_locked - -# img2img Static Resolution buttons -class img2imgResButtons(ToolButton): - is_locked = False - - def __init__(self, res=(512, 512), **kwargs): - super().__init__(**kwargs) - self.w, self.h = res - - def reset(self, avg): - # Get average of current width/height values - if not img2imgResButtons.is_locked: - avg = avg_from_dims(self.w, self.h) - return avg, self.w, self.h - # Toggle all buttons in subclass - @classmethod - def toggle_lock(cls): - cls.is_locked = not cls.is_locked - -# Hack for Gradio 4.0; see `get_component_class_id` in `venv/lib/site-packages/gradio/components/base.py` -txt2imgARButtons.__module__ = "modules.ui_components" -img2imgARButtons.__module__ = "modules.ui_components" -txt2imgResButtons.__module__ = "modules.ui_components" -img2imgResButtons.__module__ = "modules.ui_components" + return buttons, functions # Get values for Aspect Ratios from file def parse_aspect_ratios_file(filename): @@ -293,7 +253,7 @@ class AspectRatioScript(scripts.Script): if not ar_file.exists(): write_aspect_ratios_file(ar_file) (self.aspect_ratios, self.aspect_ratio_comments, self.flipped_vals) = parse_aspect_ratios_file("aspect_ratios.txt") - self.ar_btns_labels = self.aspect_ratios + self.ar_buttons_labels = self.aspect_ratios def read_resolutions(self): res_file = Path(BASE_PATH, "resolutions.txt") @@ -318,9 +278,17 @@ class AspectRatioScript(scripts.Script): self.INFO_CLOSE_ICON = "\U00002BC5" # ⯅ self.OFFSET_ICON = "\U00002B83" # ⮃ self.ONE_DIM_ICON = "\U00002B85" # ⮅ - + + # Determine the width and height based on the mode (img2img or txt2img) + if is_img2img: + w = self.i2i_w + h = self.i2i_h + else: + w = self.t2i_w + h = self.t2i_h + # Average number box initialize without rendering - arc_avg = gr.Number(label="Current W/H Avg.", value=0, interactive=False, render=False) + arc_average = gr.Number(label="Current W/H Avg.", value=0, interactive=False, render=False) # Precision input box initialize without rendering arc_prec = gr.Number(label="Precision (px)", value=64, minimum=4, maximum=128, step=4, precision=0, render=False) @@ -339,13 +307,11 @@ class AspectRatioScript(scripts.Script): arc_lock = ToolButton(value=self.LOCK_OPEN_ICON, visible=True, variant="secondary", elem_id="arsp__arc_lock_button") # Lock button click event handling def toggle_lock(icon, avg, w=512, h=512): - icon = self.LOCK_OPEN_ICON if (img2imgARButtons.is_locked if is_img2img else txt2imgARButtons.is_locked) else self.LOCK_CLOSED_ICON + icon = self.LOCK_OPEN_ICON if (img2img_state.is_locked if is_img2img else txt2img_state.is_locked) else self.LOCK_CLOSED_ICON if is_img2img: - img2imgARButtons.toggle_lock() - img2imgResButtons.toggle_lock() + img2img_state.toggle_lock() else: - txt2imgARButtons.toggle_lock() - txt2imgResButtons.toggle_lock() + txt2img_state.toggle_lock() if not avg: avg = avg_from_dims(w, h) return icon, avg @@ -356,13 +322,16 @@ class AspectRatioScript(scripts.Script): lock_w = self.t2i_w lock_h = self.t2i_h # Lock button event listener - arc_lock.click(toggle_lock, inputs=[arc_lock, arc_avg, lock_w, lock_h], outputs=[arc_lock, arc_avg]) + arc_lock.click(toggle_lock, + inputs = [arc_lock, arc_average, lock_w, lock_h], + outputs = [arc_lock, arc_average], + show_progress = 'hidden') # Initialize Aspect Ratio buttons (render=False) - ar_btns = [img2imgARButtons(value=ar, render=False) if is_img2img else txt2imgARButtons(value=ar, render=False) for ar in self.ar_btns_labels] + ar_buttons, ar_functions = create_ar_buttons(self.ar_buttons_labels, is_img2img) # Switch button - arc_sw = ToolButton(value=self.LAND_AR_ICON, visible=True, variant="secondary", elem_id="arsp__arc_sw_button") + arc_switch = ToolButton(value=self.LAND_AR_ICON, visible=True, variant="secondary", elem_id="arsp__arc_switch_button") # Switch button click event handling def toggle_switch(*items, **kwargs): ar_icons = items[:-1] @@ -373,27 +342,26 @@ class AspectRatioScript(scripts.Script): ar_icons = tuple(self.aspect_ratios) sw_icon = self.PORT_AR_ICON if sw_icon == self.LAND_AR_ICON else self.LAND_AR_ICON if is_img2img: - img2imgARButtons.toggle_switch() + img2img_state.toggle_switch() else: - txt2imgARButtons.toggle_switch() + txt2img_state.toggle_switch() return (*ar_icons, sw_icon) # Switch button event listener - arc_sw.click(toggle_switch, inputs=ar_btns+[arc_sw], outputs=ar_btns+[arc_sw]) + arc_switch.click(toggle_switch, + inputs = ar_buttons+[arc_switch], + outputs = ar_buttons+[arc_switch]) # AR buttons render - for b in ar_btns: - b.render() + for button in ar_buttons: + button.render() # AR buttons click event handling with contextlib.suppress(AttributeError): - for b in ar_btns: - if is_img2img: - w = self.i2i_w - h = self.i2i_h - else: - w = self.t2i_w - h = self.t2i_h + for button in ar_buttons: # AR buttons event listener - b.click(b.apply, inputs=[arc_avg, arc_prec, w, h], outputs=[arc_avg, w, h]) + button.click(ar_functions[button], + inputs = [arc_average, arc_prec, w, h], + outputs = [arc_average, w, h], + show_progress = 'hidden') # Get static resolutions from file self.read_resolutions() @@ -413,19 +381,25 @@ class AspectRatioScript(scripts.Script): def toggle_mode(icon): icon = self.ONE_DIM_ICON if icon == self.OFFSET_ICON else self.OFFSET_ICON if is_img2img: - img2imgARButtons.toggle_mode() + img2img_state.toggle_mode() else: - txt2imgARButtons.toggle_mode() + txt2img_state.toggle_mode() return icon # Mode button event listener - arc_mode.click(toggle_mode, inputs=[arc_mode], outputs=[arc_mode]) + arc_mode.click(toggle_mode, + inputs = [arc_mode], + outputs = [arc_mode]) # Static res buttons - btns = [img2imgResButtons(res=res, value=label) if is_img2img else txt2imgResButtons(res=res, value=label) for res, label in zip(self.res, self.res_labels)] - # Static res buttons event listener + buttons, set_res_functions = create_res_buttons(zip(self.res, self.res_labels), is_img2img) + + # Set up click event listeners for the buttons with contextlib.suppress(AttributeError): - for b in btns: - b.click(b.reset, inputs=[arc_avg], outputs=[arc_avg, w, h]) + for button in buttons: + button.click(set_res_functions[button], + inputs = [arc_average], + outputs = [arc_average, w, h], + show_progress = 'hidden') # Write button_titles.js with labels and comments read from aspect ratios and resolutions files button_titles = [self.aspect_ratios + self.res_labels] @@ -438,7 +412,7 @@ class AspectRatioScript(scripts.Script): with gr.Column(scale=2, min_width=100): # Average number box render - arc_avg.render() + arc_average.render() # Precision input box render arc_prec.render() @@ -493,4 +467,4 @@ class AspectRatioScript(scripts.Script): if kwargs.get("elem_id") == "img2img_width": self.i2i_w = component if kwargs.get("elem_id") == "img2img_height": - self.i2i_h = component \ No newline at end of file + self.i2i_h = component diff --git a/style.css b/style.css index 45b7e98..ab0ddf7 100644 --- a/style.css +++ b/style.css @@ -22,7 +22,7 @@ button#arsp__arc_lock_button, button#arsp__arc_mode_button, -button#arsp__arc_sw_button, +button#arsp__arc_switch_button, button#arsp__arc_show_info_button, button#arsp__arc_hide_info_button { max-width: 40px !important;