From a82d5d177c065f8a6251850e233dca97f15c13ac Mon Sep 17 00:00:00 2001 From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com> Date: Mon, 23 Sep 2024 14:53:58 +0100 Subject: [PATCH] restore lora version filtering (#1885) Added Flux to lora types in extra networks UI, so user can set. Loras versioned first by user-set type, if any. Falls back to heuristics - these are much more reliable than the removed old A1111 tests and in case of no match default to Unknown (always displayed). Filtering is done based on UI setting. 'all' setting does not filter. Filters lora lists on change. Removed unused 'lora_hide_unknown_for_versions' setting. --- extensions-builtin/sd_forge_lora/network.py | 26 +++++++++++++ .../sd_forge_lora/scripts/lora_script.py | 1 - .../sd_forge_lora/ui_edit_user_metadata.py | 4 +- .../sd_forge_lora/ui_extra_networks_lora.py | 17 +++++++-- javascript/extraNetworks.js | 38 +++++++++++++++++++ modules/ui_extra_networks.py | 4 +- modules_forge/main_entry.py | 3 +- 7 files changed, 83 insertions(+), 10 deletions(-) diff --git a/extensions-builtin/sd_forge_lora/network.py b/extensions-builtin/sd_forge_lora/network.py index 63de6562..ec6bf663 100644 --- a/extensions-builtin/sd_forge_lora/network.py +++ b/extensions-builtin/sd_forge_lora/network.py @@ -1,10 +1,18 @@ import os +import enum from modules import sd_models, cache, errors, hashes, shared metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} +class SdVersion(enum.Enum): + Unknown = 1 + SD1 = 2 + SD2 = 3 + SDXL = 4 +# SD3 = 5 + Flux = 6 class NetworkOnDisk: def __init__(self, name, filename): @@ -41,6 +49,24 @@ class NetworkOnDisk: '' ) + self.sd_version = self.detect_version() + + def detect_version(self): + if str(self.metadata.get('modelspec.implementation', '')) == 'https://github.com/black-forest-labs/flux': + return SdVersion.Flux + elif str(self.metadata.get('modelspec.architecture', '')) == 'flux-1-dev/lora': + return SdVersion.Flux + elif str(self.metadata.get('modelspec.architecture', '')) == 'stable-diffusion-xl-v1-base/lora': + return SdVersion.SDXL + elif str(self.metadata.get('ss_base_model_version', '')).startswith('sdxl_'): + return SdVersion.SDXL + elif str(self.metadata.get('ss_v2', '')) == 'True': + return SdVersion.SD2 + elif str(self.metadata.get('modelspec.architecture', '')) == 'stable-diffusion-v1/lora': + return SdVersion.SD1 + + return SdVersion.Unknown + def set_hash(self, v): self.hash = v self.shorthash = self.hash[0:12] diff --git a/extensions-builtin/sd_forge_lora/scripts/lora_script.py b/extensions-builtin/sd_forge_lora/scripts/lora_script.py index a8a26d51..50864882 100644 --- a/extensions-builtin/sd_forge_lora/scripts/lora_script.py +++ b/extensions-builtin/sd_forge_lora/scripts/lora_script.py @@ -28,7 +28,6 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), "lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'), "lora_filter_disabled": shared.OptionInfo(True, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), - "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), diff --git a/extensions-builtin/sd_forge_lora/ui_edit_user_metadata.py b/extensions-builtin/sd_forge_lora/ui_edit_user_metadata.py index ecda5c33..ebf34de9 100644 --- a/extensions-builtin/sd_forge_lora/ui_edit_user_metadata.py +++ b/extensions-builtin/sd_forge_lora/ui_edit_user_metadata.py @@ -158,9 +158,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) return ", ".join(sorted(res)) def create_extra_default_items_in_left_column(self): - - # this would be a lot better as gr.Radio but I can't make it work - self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) + self.select_sd_version = gr.Radio(['SD1', 'SD2', 'SDXL', 'Flux', 'Unknown'], value='Unknown', label='Base model', interactive=True) def create_editor(self): self.create_default_editor_elems() diff --git a/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py b/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py index 10e4e024..9a649b1b 100644 --- a/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py +++ b/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py @@ -48,10 +48,19 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) - negative_prompt = item["user_metadata"].get("negative text") - item["negative_prompt"] = quote_js("") - if negative_prompt: - item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') + negative_prompt = item["user_metadata"].get("negative text", "") + item["negative_prompt"] = quote_js(negative_prompt) + + # filter displayed loras by UI setting + sd_version = item["user_metadata"].get("sd version") + if sd_version in network.SdVersion.__members__: + item["sd_version"] = sd_version + sd_version = network.SdVersion[sd_version] + else: + sd_version = lora_on_disk.sd_version # use heuristics + #sd_version = network.SdVersion.Unknown # avoid heuristics + + item["sd_version_str"] = str(sd_version) return item diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index c5cced97..5a9c9bcd 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -50,6 +50,17 @@ function setupExtraNetworksForTab(tabname) { var applyFilter = function(force) { var searchTerm = search.value.toLowerCase(); + + // get UI preset + radioUI = gradioApp().querySelector('#forge_ui_preset'); + radioButtons = radioUI.getElementsByTagName('input'); + UIresult = 3; // default to 'all' + for (i = 0; i < radioButtons.length; i++) { + if (radioButtons[i].checked) { + UIresult = i; + } + } + gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) { @@ -60,6 +71,21 @@ function setupExtraNetworksForTab(tabname) { if (searchOnly && searchTerm.length < 4) { visible = false; } + + sdversion = elem.getAttribute('data-sort-sdversion'); + if (sdversion == null) ; + else if (sdversion == 'SdVersion.Unknown') ; + else if (UIresult == 3) ; // 'all' + else if (UIresult == 0) { // 'sd' + if (sdversion != 'SdVersion.SD1' && sdversion != 'SdVersion.SD2') visible = false; + } + else if (UIresult == 1) { // 'xl' + if (sdversion != 'SdVersion.SDXL') visible = false; + } + else if (UIresult == 2) { // 'flux' + if (sdversion != 'SdVersion.Flux') visible = false; + } + if (visible) { elem.classList.remove("hidden"); } else { @@ -70,6 +96,7 @@ function setupExtraNetworksForTab(tabname) { applySort(force); }; + var applySort = function(force) { var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card'); var parent = gradioApp().querySelector('#' + tabname_full + "_cards"); @@ -449,6 +476,17 @@ function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabn pane.classList.toggle("extra-network-dirs-hidden", show); } +function clickLoraRefresh() { + var applyFunction = extraNetworksApplyFilter['txt2img_lora']; + if (applyFunction) { + applyFunction(true); + } + applyFunction = extraNetworksApplyFilter['img2img_lora']; + if (applyFunction) { + applyFunction(true); + } +} + function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Refresh Page button. diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 395549bf..beb43e50 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -304,10 +304,12 @@ class ExtraNetworksPage: if search_only and shared.opts.extra_networks_hidden_models == "Never": return "" + item_sort_keys = item.get("sort_keys", {}) + item_sort_keys["SDversion"] = item.get("sd_version_str", "SdVersion.Unknown") sort_keys = " ".join( [ f'data-sort-{k}="{html.escape(str(v))}"' - for k, v in item.get("sort_keys", {}).items() + for k, v in item_sort_keys.items() ] ).strip() diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py index 1bddb4c0..77bc869a 100644 --- a/modules_forge/main_entry.py +++ b/modules_forge/main_entry.py @@ -62,7 +62,7 @@ def make_checkpoint_manager_ui(): if len(sd_models.checkpoints_list) > 0: shared.opts.set('sd_model_checkpoint', next(iter(sd_models.checkpoints_list.values())).name) - ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all']) + ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'], elem_id="forge_ui_preset") ckpt_list, vae_list = refresh_models() @@ -293,6 +293,7 @@ def forge_main_entry(): ] ui_forge_preset.change(on_preset_change, inputs=[ui_forge_preset], outputs=output_targets, queue=False, show_progress=False) + ui_forge_preset.change(js="clickLoraRefresh", fn=None, queue=False, show_progress=False) Context.root_block.load(on_preset_change, inputs=None, outputs=output_targets, queue=False, show_progress=False) refresh_model_loading_parameters()