diff --git a/extensions-builtin/sd_forge_lora/network.py b/extensions-builtin/sd_forge_lora/network.py index 63de6562..ec6bf663 100644 --- a/extensions-builtin/sd_forge_lora/network.py +++ b/extensions-builtin/sd_forge_lora/network.py @@ -1,10 +1,18 @@ import os +import enum from modules import sd_models, cache, errors, hashes, shared metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} +class SdVersion(enum.Enum): + Unknown = 1 + SD1 = 2 + SD2 = 3 + SDXL = 4 +# SD3 = 5 + Flux = 6 class NetworkOnDisk: def __init__(self, name, filename): @@ -41,6 +49,24 @@ class NetworkOnDisk: '' ) + self.sd_version = self.detect_version() + + def detect_version(self): + if str(self.metadata.get('modelspec.implementation', '')) == 'https://github.com/black-forest-labs/flux': + return SdVersion.Flux + elif str(self.metadata.get('modelspec.architecture', '')) == 'flux-1-dev/lora': + return SdVersion.Flux + elif str(self.metadata.get('modelspec.architecture', '')) == 'stable-diffusion-xl-v1-base/lora': + return SdVersion.SDXL + elif str(self.metadata.get('ss_base_model_version', '')).startswith('sdxl_'): + return SdVersion.SDXL + elif str(self.metadata.get('ss_v2', '')) == 'True': + return SdVersion.SD2 + elif str(self.metadata.get('modelspec.architecture', '')) == 'stable-diffusion-v1/lora': + return SdVersion.SD1 + + return SdVersion.Unknown + def set_hash(self, v): self.hash = v self.shorthash = self.hash[0:12] diff --git a/extensions-builtin/sd_forge_lora/scripts/lora_script.py b/extensions-builtin/sd_forge_lora/scripts/lora_script.py index a8a26d51..50864882 100644 --- a/extensions-builtin/sd_forge_lora/scripts/lora_script.py +++ b/extensions-builtin/sd_forge_lora/scripts/lora_script.py @@ -28,7 +28,6 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), "lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'), "lora_filter_disabled": shared.OptionInfo(True, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), - "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), diff --git a/extensions-builtin/sd_forge_lora/ui_edit_user_metadata.py b/extensions-builtin/sd_forge_lora/ui_edit_user_metadata.py index ecda5c33..ebf34de9 100644 --- a/extensions-builtin/sd_forge_lora/ui_edit_user_metadata.py +++ b/extensions-builtin/sd_forge_lora/ui_edit_user_metadata.py @@ -158,9 +158,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) return ", ".join(sorted(res)) def create_extra_default_items_in_left_column(self): - - # this would be a lot better as gr.Radio but I can't make it work - self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) + self.select_sd_version = gr.Radio(['SD1', 'SD2', 'SDXL', 'Flux', 'Unknown'], value='Unknown', label='Base model', interactive=True) def create_editor(self): self.create_default_editor_elems() diff --git a/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py b/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py index 10e4e024..9a649b1b 100644 --- a/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py +++ b/extensions-builtin/sd_forge_lora/ui_extra_networks_lora.py @@ -48,10 +48,19 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) - negative_prompt = item["user_metadata"].get("negative text") - item["negative_prompt"] = quote_js("") - if negative_prompt: - item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') + negative_prompt = item["user_metadata"].get("negative text", "") + item["negative_prompt"] = quote_js(negative_prompt) + + # filter displayed loras by UI setting + sd_version = item["user_metadata"].get("sd version") + if sd_version in network.SdVersion.__members__: + item["sd_version"] = sd_version + sd_version = network.SdVersion[sd_version] + else: + sd_version = lora_on_disk.sd_version # use heuristics + #sd_version = network.SdVersion.Unknown # avoid heuristics + + item["sd_version_str"] = str(sd_version) return item diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index c5cced97..5a9c9bcd 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -50,6 +50,17 @@ function setupExtraNetworksForTab(tabname) { var applyFilter = function(force) { var searchTerm = search.value.toLowerCase(); + + // get UI preset + radioUI = gradioApp().querySelector('#forge_ui_preset'); + radioButtons = radioUI.getElementsByTagName('input'); + UIresult = 3; // default to 'all' + for (i = 0; i < radioButtons.length; i++) { + if (radioButtons[i].checked) { + UIresult = i; + } + } + gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) { @@ -60,6 +71,21 @@ function setupExtraNetworksForTab(tabname) { if (searchOnly && searchTerm.length < 4) { visible = false; } + + sdversion = elem.getAttribute('data-sort-sdversion'); + if (sdversion == null) ; + else if (sdversion == 'SdVersion.Unknown') ; + else if (UIresult == 3) ; // 'all' + else if (UIresult == 0) { // 'sd' + if (sdversion != 'SdVersion.SD1' && sdversion != 'SdVersion.SD2') visible = false; + } + else if (UIresult == 1) { // 'xl' + if (sdversion != 'SdVersion.SDXL') visible = false; + } + else if (UIresult == 2) { // 'flux' + if (sdversion != 'SdVersion.Flux') visible = false; + } + if (visible) { elem.classList.remove("hidden"); } else { @@ -70,6 +96,7 @@ function setupExtraNetworksForTab(tabname) { applySort(force); }; + var applySort = function(force) { var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card'); var parent = gradioApp().querySelector('#' + tabname_full + "_cards"); @@ -449,6 +476,17 @@ function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabn pane.classList.toggle("extra-network-dirs-hidden", show); } +function clickLoraRefresh() { + var applyFunction = extraNetworksApplyFilter['txt2img_lora']; + if (applyFunction) { + applyFunction(true); + } + applyFunction = extraNetworksApplyFilter['img2img_lora']; + if (applyFunction) { + applyFunction(true); + } +} + function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Refresh Page button. diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 395549bf..beb43e50 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -304,10 +304,12 @@ class ExtraNetworksPage: if search_only and shared.opts.extra_networks_hidden_models == "Never": return "" + item_sort_keys = item.get("sort_keys", {}) + item_sort_keys["SDversion"] = item.get("sd_version_str", "SdVersion.Unknown") sort_keys = " ".join( [ f'data-sort-{k}="{html.escape(str(v))}"' - for k, v in item.get("sort_keys", {}).items() + for k, v in item_sort_keys.items() ] ).strip() diff --git a/modules_forge/main_entry.py b/modules_forge/main_entry.py index 1bddb4c0..77bc869a 100644 --- a/modules_forge/main_entry.py +++ b/modules_forge/main_entry.py @@ -62,7 +62,7 @@ def make_checkpoint_manager_ui(): if len(sd_models.checkpoints_list) > 0: shared.opts.set('sd_model_checkpoint', next(iter(sd_models.checkpoints_list.values())).name) - ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all']) + ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'], elem_id="forge_ui_preset") ckpt_list, vae_list = refresh_models() @@ -293,6 +293,7 @@ def forge_main_entry(): ] ui_forge_preset.change(on_preset_change, inputs=[ui_forge_preset], outputs=output_targets, queue=False, show_progress=False) + ui_forge_preset.change(js="clickLoraRefresh", fn=None, queue=False, show_progress=False) Context.root_block.load(on_preset_change, inputs=None, outputs=output_targets, queue=False, show_progress=False) refresh_model_loading_parameters()