mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-13 08:59:51 +00:00
restore lora version filtering (#1885)
Added Flux to lora types in extra networks UI, so user can set. Loras versioned first by user-set type, if any. Falls back to heuristics - these are much more reliable than the removed old A1111 tests and in case of no match default to Unknown (always displayed). Filtering is done based on UI setting. 'all' setting does not filter. Filters lora lists on change. Removed unused 'lora_hide_unknown_for_versions' setting.
This commit is contained in:
@@ -1,10 +1,18 @@
|
||||
import os
|
||||
import enum
|
||||
|
||||
from modules import sd_models, cache, errors, hashes, shared
|
||||
|
||||
|
||||
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||
|
||||
class SdVersion(enum.Enum):
|
||||
Unknown = 1
|
||||
SD1 = 2
|
||||
SD2 = 3
|
||||
SDXL = 4
|
||||
# SD3 = 5
|
||||
Flux = 6
|
||||
|
||||
class NetworkOnDisk:
|
||||
def __init__(self, name, filename):
|
||||
@@ -41,6 +49,24 @@ class NetworkOnDisk:
|
||||
''
|
||||
)
|
||||
|
||||
self.sd_version = self.detect_version()
|
||||
|
||||
def detect_version(self):
|
||||
if str(self.metadata.get('modelspec.implementation', '')) == 'https://github.com/black-forest-labs/flux':
|
||||
return SdVersion.Flux
|
||||
elif str(self.metadata.get('modelspec.architecture', '')) == 'flux-1-dev/lora':
|
||||
return SdVersion.Flux
|
||||
elif str(self.metadata.get('modelspec.architecture', '')) == 'stable-diffusion-xl-v1-base/lora':
|
||||
return SdVersion.SDXL
|
||||
elif str(self.metadata.get('ss_base_model_version', '')).startswith('sdxl_'):
|
||||
return SdVersion.SDXL
|
||||
elif str(self.metadata.get('ss_v2', '')) == 'True':
|
||||
return SdVersion.SD2
|
||||
elif str(self.metadata.get('modelspec.architecture', '')) == 'stable-diffusion-v1/lora':
|
||||
return SdVersion.SD1
|
||||
|
||||
return SdVersion.Unknown
|
||||
|
||||
def set_hash(self, v):
|
||||
self.hash = v
|
||||
self.shorthash = self.hash[0:12]
|
||||
|
||||
@@ -28,7 +28,6 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
|
||||
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
||||
"lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
|
||||
"lora_filter_disabled": shared.OptionInfo(True, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
||||
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
||||
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
|
||||
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
|
||||
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
|
||||
|
||||
@@ -158,9 +158,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
||||
return ", ".join(sorted(res))
|
||||
|
||||
def create_extra_default_items_in_left_column(self):
|
||||
|
||||
# this would be a lot better as gr.Radio but I can't make it work
|
||||
self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
|
||||
self.select_sd_version = gr.Radio(['SD1', 'SD2', 'SDXL', 'Flux', 'Unknown'], value='Unknown', label='Base model', interactive=True)
|
||||
|
||||
def create_editor(self):
|
||||
self.create_default_editor_elems()
|
||||
|
||||
@@ -48,10 +48,19 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
if activation_text:
|
||||
item["prompt"] += " + " + quote_js(" " + activation_text)
|
||||
|
||||
negative_prompt = item["user_metadata"].get("negative text")
|
||||
item["negative_prompt"] = quote_js("")
|
||||
if negative_prompt:
|
||||
item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
|
||||
negative_prompt = item["user_metadata"].get("negative text", "")
|
||||
item["negative_prompt"] = quote_js(negative_prompt)
|
||||
|
||||
# filter displayed loras by UI setting
|
||||
sd_version = item["user_metadata"].get("sd version")
|
||||
if sd_version in network.SdVersion.__members__:
|
||||
item["sd_version"] = sd_version
|
||||
sd_version = network.SdVersion[sd_version]
|
||||
else:
|
||||
sd_version = lora_on_disk.sd_version # use heuristics
|
||||
#sd_version = network.SdVersion.Unknown # avoid heuristics
|
||||
|
||||
item["sd_version_str"] = str(sd_version)
|
||||
|
||||
return item
|
||||
|
||||
|
||||
@@ -50,6 +50,17 @@ function setupExtraNetworksForTab(tabname) {
|
||||
|
||||
var applyFilter = function(force) {
|
||||
var searchTerm = search.value.toLowerCase();
|
||||
|
||||
// get UI preset
|
||||
radioUI = gradioApp().querySelector('#forge_ui_preset');
|
||||
radioButtons = radioUI.getElementsByTagName('input');
|
||||
UIresult = 3; // default to 'all'
|
||||
for (i = 0; i < radioButtons.length; i++) {
|
||||
if (radioButtons[i].checked) {
|
||||
UIresult = i;
|
||||
}
|
||||
}
|
||||
|
||||
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
|
||||
var searchOnly = elem.querySelector('.search_only');
|
||||
var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) {
|
||||
@@ -60,6 +71,21 @@ function setupExtraNetworksForTab(tabname) {
|
||||
if (searchOnly && searchTerm.length < 4) {
|
||||
visible = false;
|
||||
}
|
||||
|
||||
sdversion = elem.getAttribute('data-sort-sdversion');
|
||||
if (sdversion == null) ;
|
||||
else if (sdversion == 'SdVersion.Unknown') ;
|
||||
else if (UIresult == 3) ; // 'all'
|
||||
else if (UIresult == 0) { // 'sd'
|
||||
if (sdversion != 'SdVersion.SD1' && sdversion != 'SdVersion.SD2') visible = false;
|
||||
}
|
||||
else if (UIresult == 1) { // 'xl'
|
||||
if (sdversion != 'SdVersion.SDXL') visible = false;
|
||||
}
|
||||
else if (UIresult == 2) { // 'flux'
|
||||
if (sdversion != 'SdVersion.Flux') visible = false;
|
||||
}
|
||||
|
||||
if (visible) {
|
||||
elem.classList.remove("hidden");
|
||||
} else {
|
||||
@@ -70,6 +96,7 @@ function setupExtraNetworksForTab(tabname) {
|
||||
applySort(force);
|
||||
};
|
||||
|
||||
|
||||
var applySort = function(force) {
|
||||
var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card');
|
||||
var parent = gradioApp().querySelector('#' + tabname_full + "_cards");
|
||||
@@ -449,6 +476,17 @@ function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabn
|
||||
pane.classList.toggle("extra-network-dirs-hidden", show);
|
||||
}
|
||||
|
||||
function clickLoraRefresh() {
|
||||
var applyFunction = extraNetworksApplyFilter['txt2img_lora'];
|
||||
if (applyFunction) {
|
||||
applyFunction(true);
|
||||
}
|
||||
applyFunction = extraNetworksApplyFilter['img2img_lora'];
|
||||
if (applyFunction) {
|
||||
applyFunction(true);
|
||||
}
|
||||
}
|
||||
|
||||
function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) {
|
||||
/**
|
||||
* Handles `onclick` events for the Refresh Page button.
|
||||
|
||||
@@ -304,10 +304,12 @@ class ExtraNetworksPage:
|
||||
if search_only and shared.opts.extra_networks_hidden_models == "Never":
|
||||
return ""
|
||||
|
||||
item_sort_keys = item.get("sort_keys", {})
|
||||
item_sort_keys["SDversion"] = item.get("sd_version_str", "SdVersion.Unknown")
|
||||
sort_keys = " ".join(
|
||||
[
|
||||
f'data-sort-{k}="{html.escape(str(v))}"'
|
||||
for k, v in item.get("sort_keys", {}).items()
|
||||
for k, v in item_sort_keys.items()
|
||||
]
|
||||
).strip()
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ def make_checkpoint_manager_ui():
|
||||
if len(sd_models.checkpoints_list) > 0:
|
||||
shared.opts.set('sd_model_checkpoint', next(iter(sd_models.checkpoints_list.values())).name)
|
||||
|
||||
ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'])
|
||||
ui_forge_preset = gr.Radio(label="UI", value=lambda: shared.opts.forge_preset, choices=['sd', 'xl', 'flux', 'all'], elem_id="forge_ui_preset")
|
||||
|
||||
ckpt_list, vae_list = refresh_models()
|
||||
|
||||
@@ -293,6 +293,7 @@ def forge_main_entry():
|
||||
]
|
||||
|
||||
ui_forge_preset.change(on_preset_change, inputs=[ui_forge_preset], outputs=output_targets, queue=False, show_progress=False)
|
||||
ui_forge_preset.change(js="clickLoraRefresh", fn=None, queue=False, show_progress=False)
|
||||
Context.root_block.load(on_preset_change, inputs=None, outputs=output_targets, queue=False, show_progress=False)
|
||||
|
||||
refresh_model_loading_parameters()
|
||||
|
||||
Reference in New Issue
Block a user