restore lora version filtering (#1885)

Added Flux to lora types in extra networks UI, so user can set.
Loras versioned first by user-set type, if any. Falls back to heuristics - these are much more reliable than the removed old A1111 tests and in case of no match default to Unknown (always displayed).
Filtering is done based on UI setting. 'all' setting does not filter. Filters lora lists on change.
Removed unused 'lora_hide_unknown_for_versions' setting.
This commit is contained in:
DenOfEquity
2024-09-23 14:53:58 +01:00
committed by GitHub
parent 95b54a27f1
commit a82d5d177c
7 changed files with 83 additions and 10 deletions

View File

@@ -1,10 +1,18 @@
import os
import enum
from modules import sd_models, cache, errors, hashes, shared
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
class SdVersion(enum.Enum):
Unknown = 1
SD1 = 2
SD2 = 3
SDXL = 4
# SD3 = 5
Flux = 6
class NetworkOnDisk:
def __init__(self, name, filename):
@@ -41,6 +49,24 @@ class NetworkOnDisk:
''
)
self.sd_version = self.detect_version()
def detect_version(self):
if str(self.metadata.get('modelspec.implementation', '')) == 'https://github.com/black-forest-labs/flux':
return SdVersion.Flux
elif str(self.metadata.get('modelspec.architecture', '')) == 'flux-1-dev/lora':
return SdVersion.Flux
elif str(self.metadata.get('modelspec.architecture', '')) == 'stable-diffusion-xl-v1-base/lora':
return SdVersion.SDXL
elif str(self.metadata.get('ss_base_model_version', '')).startswith('sdxl_'):
return SdVersion.SDXL
elif str(self.metadata.get('ss_v2', '')) == 'True':
return SdVersion.SD2
elif str(self.metadata.get('modelspec.architecture', '')) == 'stable-diffusion-v1/lora':
return SdVersion.SD1
return SdVersion.Unknown
def set_hash(self, v):
self.hash = v
self.shorthash = self.hash[0:12]

View File

@@ -28,7 +28,6 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
"lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
"lora_filter_disabled": shared.OptionInfo(True, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),

View File

@@ -158,9 +158,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
return ", ".join(sorted(res))
def create_extra_default_items_in_left_column(self):
# this would be a lot better as gr.Radio but I can't make it work
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
self.select_sd_version = gr.Radio(['SD1', 'SD2', 'SDXL', 'Flux', 'Unknown'], value='Unknown', label='Base model', interactive=True)
def create_editor(self):
self.create_default_editor_elems()

View File

@@ -48,10 +48,19 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
if activation_text:
item["prompt"] += " + " + quote_js(" " + activation_text)
negative_prompt = item["user_metadata"].get("negative text")
item["negative_prompt"] = quote_js("")
if negative_prompt:
item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
negative_prompt = item["user_metadata"].get("negative text", "")
item["negative_prompt"] = quote_js(negative_prompt)
# filter displayed loras by UI setting
sd_version = item["user_metadata"].get("sd version")
if sd_version in network.SdVersion.__members__:
item["sd_version"] = sd_version
sd_version = network.SdVersion[sd_version]
else:
sd_version = lora_on_disk.sd_version # use heuristics
#sd_version = network.SdVersion.Unknown # avoid heuristics
item["sd_version_str"] = str(sd_version)
return item

View File

@@ -50,6 +50,17 @@ function setupExtraNetworksForTab(tabname) {
var applyFilter = function(force) {
var searchTerm = search.value.toLowerCase();
// get UI preset
radioUI = gradioApp().querySelector('#forge_ui_preset');
radioButtons = radioUI.getElementsByTagName('input');
UIresult = 3; // default to 'all'
for (i = 0; i < radioButtons.length; i++) {
if (radioButtons[i].checked) {
UIresult = i;
}
}
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
var searchOnly = elem.querySelector('.search_only');
var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) {
@@ -60,6 +71,21 @@ function setupExtraNetworksForTab(tabname) {
if (searchOnly && searchTerm.length < 4) {
visible = false;
}
sdversion = elem.getAttribute('data-sort-sdversion');
if (sdversion == null) ;
else if (sdversion == 'SdVersion.Unknown') ;
else if (UIresult == 3) ; // 'all'
else if (UIresult == 0) { // 'sd'
if (sdversion != 'SdVersion.SD1' && sdversion != 'SdVersion.SD2') visible = false;
}
else if (UIresult == 1) { // 'xl'
if (sdversion != 'SdVersion.SDXL') visible = false;
}
else if (UIresult == 2) { // 'flux'
if (sdversion != 'SdVersion.Flux') visible = false;
}
if (visible) {
elem.classList.remove("hidden");
} else {
@@ -70,6 +96,7 @@ function setupExtraNetworksForTab(tabname) {
applySort(force);
};
var applySort = function(force) {
var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card');
var parent = gradioApp().querySelector('#' + tabname_full + "_cards");
@@ -449,6 +476,17 @@ function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabn
pane.classList.toggle("extra-network-dirs-hidden", show);
}
function clickLoraRefresh() {
var applyFunction = extraNetworksApplyFilter['txt2img_lora'];
if (applyFunction) {
applyFunction(true);
}
applyFunction = extraNetworksApplyFilter['img2img_lora'];
if (applyFunction) {
applyFunction(true);
}
}
function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) {
/**
* Handles `onclick` events for the Refresh Page button.

View File

@@ -304,10 +304,12 @@ class ExtraNetworksPage:
if search_only and shared.opts.extra_networks_hidden_models == "Never":
return ""
item_sort_keys = item.get("sort_keys", {})
item_sort_keys["SDversion"] = item.get("sd_version_str", "SdVersion.Unknown")
sort_keys = " ".join(
[
f'data-sort-{k}="{html.escape(str(v))}"'
for k, v in item.get("sort_keys", {}).items()
for k, v in item_sort_keys.items()
]
).strip()

View File

@@ -62,7 +62,7 @@ def make_checkpoint_manager_ui():
if len(sd_models.checkpoints_list) > 0:
shared.opts.set('sd_model_checkpoint', next(iter(sd_models.checkpoints_list.values())).name)
ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'])
ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'], elem_id="forge_ui_preset")
ckpt_list, vae_list = refresh_models()
@@ -293,6 +293,7 @@ def forge_main_entry():
]
ui_forge_preset.change(on_preset_change, inputs=[ui_forge_preset], outputs=output_targets, queue=False, show_progress=False)
ui_forge_preset.change(js="clickLoraRefresh", fn=None, queue=False, show_progress=False)
Context.root_block.load(on_preset_change, inputs=None, outputs=output_targets, queue=False, show_progress=False)
refresh_model_loading_parameters()