diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index d1c46a4b..d658ad10 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -22,6 +22,8 @@ class NetworkModuleOFT(network.NetworkModule): self.org_module: list[torch.Module] = [self.sd_module] self.scale = 1.0 + self.is_kohya = False + self.is_boft = False # kohya-ss if "oft_blocks" in weights.w.keys(): @@ -29,13 +31,19 @@ class NetworkModuleOFT(network.NetworkModule): self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) self.alpha = weights.w["alpha"] # alpha is constraint self.dim = self.oft_blocks.shape[0] # lora dim - # LyCORIS + # LyCORIS OFT elif "oft_diag" in weights.w.keys(): - self.is_kohya = False self.oft_blocks = weights.w["oft_diag"] # self.alpha is unused self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + # LyCORIS BOFT + if weights.w["oft_diag"].dim() == 4: + self.is_boft = True + self.rescale = weights.w.get('rescale', None) + if self.rescale is not None: + self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1)) + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported @@ -51,6 +59,13 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = self.alpha * self.out_dim self.num_blocks = self.dim self.block_size = self.out_dim // self.dim + elif self.is_boft: + self.constraint = None + self.boft_m = weights.w["oft_diag"].shape[0] + self.block_num = weights.w["oft_diag"].shape[1] + self.block_size = weights.w["oft_diag"].shape[2] + self.boft_b = self.block_size + #self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim) else: self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) @@ -68,14 +83,37 @@ class NetworkModuleOFT(network.NetworkModule): R = oft_blocks.to(orig_weight.device) - # This errors out for MultiheadAttention, might need to be handled up-stream - merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - merged_weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R, - merged_weight - ) - merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if not self.is_boft: + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + else: + # TODO: determine correct value for scale + scale = 1.0 + m = self.boft_m + b = self.boft_b + r_b = b // 2 + inp = orig_weight + for i in range(m): + bi = R[i] # b_num, b_size, b_size + if i == 0: + # Apply multiplier/scale and rescale into first weight + bi = bi * scale + (1 - scale) * eye + inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b) + inp = rearrange(inp, "(d b) ... -> d b ...", b=b) + inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp) + inp = rearrange(inp, "d b ... -> (d b) ...") + inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b) + merged_weight = inp + + # Rescale mechanism + if self.rescale is not None: + merged_weight = self.rescale.to(merged_weight) * merged_weight updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape diff --git a/modules/processing.py b/modules/processing.py index f4aa165d..d208a922 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -74,16 +74,18 @@ def uncrop(image, dest_size, paste_loc): def apply_overlay(image, paste_loc, overlay): if overlay is None: - return image + return image, image.copy() if paste_loc is not None: image = uncrop(image, (overlay.width, overlay.height), paste_loc) + original_denoised_image = image.copy() + image = image.convert('RGBA') image.alpha_composite(overlay) image = image.convert('RGB') - return image + return image, original_denoised_image def create_binary_mask(image, round=True): if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255): @@ -1021,7 +1023,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: - image_without_cc = apply_overlay(image, p.paste_to, overlay_image) + image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image) images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) @@ -1029,12 +1031,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: # that is being composited over the original image, # we need to keep the original image around # and use it in the composite step. - original_denoised_image = image.copy() - - if p.paste_to is not None: - original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to) - - image = apply_overlay(image, p.paste_to, overlay_image) + image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image) if p.scripts is not None: pp = scripts.PostprocessImageArgs(image) diff --git a/modules/shared_options.py b/modules/shared_options.py index 25b47aa1..bb3752ba 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -284,6 +284,7 @@ options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), { "sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), "sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(), + "open_dir_button_choice": OptionInfo("Subdirectory", "What directory the [📂] button opens", gr.Radio, {"choices": ["Output Root", "Subdirectory", "Subdirectory (even temp dir)"]}), })) options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), { diff --git a/modules/ui_common.py b/modules/ui_common.py index 29fe7d0e..cf1b8b32 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -9,7 +9,7 @@ import sys import gradio as gr import subprocess as sp -from modules import call_queue, shared +from modules import call_queue, shared, ui_tempdir from modules.infotext_utils import image_from_url_text import modules.images from modules.ui_components import ToolButton @@ -164,29 +164,43 @@ class OutputPanel: def create_output_panel(tabname, outdir, toprow=None): res = OutputPanel() - def open_folder(f): + def open_folder(f, images=None, index=None): + if shared.cmd_opts.hide_ui_dir_config: + return + + try: + if 'Sub' in shared.opts.open_dir_button_choice: + image_dir = os.path.split(images[index]["name"].rsplit('?', 1)[0])[0] + if 'temp' in shared.opts.open_dir_button_choice or not ui_tempdir.is_gradio_temp_path(image_dir): + f = image_dir + except Exception: + pass + if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + msg = f'Folder "{f}" does not exist. After you create an image, the folder will be created.' + print(msg) + gr.Info(msg) return elif not os.path.isdir(f): - print(f""" + msg = f""" WARNING An open_folder request was made with an argument that is not a folder. This could be an error or a malicious attempt to run code on your computer. Requested path was: {f} -""", file=sys.stderr) +""" + print(msg, file=sys.stderr) + gr.Warning(msg) return - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) with gr.Column(elem_id=f"{tabname}_results"): if toprow: @@ -213,8 +227,12 @@ Requested path was: {f} res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.") open_folder_button.click( - fn=lambda: open_folder(shared.opts.outdir_samples or outdir), - inputs=[], + fn=lambda images, index: open_folder(shared.opts.outdir_samples or outdir, images, index), + _js="(y, w) => [y, selected_gallery_index()]", + inputs=[ + res.gallery, + open_folder_button, # placeholder for index + ], outputs=[], ) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index c03b9f08..6874a024 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -472,7 +472,7 @@ class ExtraNetworksPage: return f"" - def create_card_view_html(self, tabname: str) -> str: + def create_card_view_html(self, tabname: str, *, none_message) -> str: """Generates HTML for the network Card View section for a tab. This HTML goes into the `extra-networks-pane.html`
with @@ -480,6 +480,7 @@ class ExtraNetworksPage: Args: tabname: The name of the active tab. + none_message: HTML text to show when there are no cards. Returns: HTML formatted string. @@ -490,24 +491,28 @@ class ExtraNetworksPage: if res == "": dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) - res = shared.html("extra-networks-no-cards.html").format(dirs=dirs) + res = none_message or shared.html("extra-networks-no-cards.html").format(dirs=dirs) return res - def create_html(self, tabname): + def create_html(self, tabname, *, empty=False): """Generates an HTML string for the current pane. The generated HTML uses `extra-networks-pane.html` as a template. Args: tabname: The name of the active tab. + empty: create an empty HTML page with no items Returns: HTML formatted string. """ self.lister.reset() self.metadata = {} - self.items = {x["name"]: x for x in self.list_items()} + + items_list = [] if empty else self.list_items() + self.items = {x["name"]: x for x in items_list} + # Populate the instance metadata for each item. for item in self.items.values(): metadata = item.get("metadata") @@ -536,7 +541,7 @@ class ExtraNetworksPage: "tree_view_btn_extra_class": tree_view_btn_extra_class, "tree_view_div_extra_class": tree_view_div_extra_class, "tree_html": self.create_tree_view_html(tabname), - "items_html": self.create_card_view_html(tabname), + "items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None), } ) @@ -655,7 +660,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): pass elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html" - page_elem = gr.HTML('Loading...', elem_id=elem_id) + page_elem = gr.HTML(page.create_html(tabname, empty=True), elem_id=elem_id) ui.pages.append(page_elem) editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index 91f40ea4..621ed1ec 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -81,3 +81,18 @@ def cleanup_tmpdr(): filename = os.path.join(root, name) os.remove(filename) + + +def is_gradio_temp_path(path): + """ + Check if the path is a temp dir used by gradio + """ + path = Path(path) + if shared.opts.temp_dir and path.is_relative_to(shared.opts.temp_dir): + return True + if gradio_temp_dir := os.environ.get("GRADIO_TEMP_DIR"): + if path.is_relative_to(gradio_temp_dir): + return True + if path.is_relative_to(Path(tempfile.gettempdir()) / "gradio"): + return True + return False