Vae/te preferences via cards (#1912)

Allows setting of preferred VAE and Text encoder(s) for checkpoints when selected via Checkpoint cards. No selection saved means no change to current toprow setting. 'Built in' option, if the only choice, means clear the toprow selection (therefore use vae/te built-in to checkpoint).
Also allows setting model type for checkpoints (SD1/SD2/SDXL/Flux/Unknown) (user set only, no attempt at autodetection), enabling filtering of the cards based on UI preset.
This commit is contained in:
DenOfEquity
2024-09-25 20:45:11 +01:00
committed by GitHub
parent c2d290e6c9
commit 7876862c43
5 changed files with 66 additions and 18 deletions

View File

@@ -362,6 +362,10 @@ function selectCheckpoint(name) {
desiredCheckpointName = name;
gradioApp().getElementById('change_checkpoint').click();
}
var desiredVAEName = 0;
function selectVAE(vae) {
desiredVAEName = vae;
}
function currentImg2imgSourceResolution(w, h, r) {
var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] :is(img, canvas)');

View File

@@ -185,6 +185,15 @@ def list_models():
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
def match_checkpoint_to_name(name):
name = name.split(' [')[0]
for ckptname in checkpoints_list.values():
title = ckptname.title.split(' [')[0]
if (name in title) or (title in name):
return ckptname.short_title if shared.opts.sd_checkpoint_dropdown_use_short else ckptname.name.split(' [')[0]
return name
def get_closet_checkpoint_match(search_string):
if not search_string:

View File

@@ -214,6 +214,12 @@ class ExtraNetworksPage:
desc = metadata.get("description", None)
if desc is not None:
item["description"] = desc
vae = metadata.get("vae", None)
if vae is not None:
item["vae"] = vae
version = metadata.get("sd_version_str", None)
if version is not None:
item["sd_version_str"] = version
item["user_metadata"] = metadata
@@ -257,7 +263,7 @@ class ExtraNetworksPage:
background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
onclick = item.get("onclick", None)
if onclick is None:
if onclick is None: # this path is 'Textual Inversion' and 'Lora'
# Don't quote prompt/neg_prompt since they are stored as js strings already.
onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});"
onclick = onclick_js_tpl.format(
@@ -269,6 +275,9 @@ class ExtraNetworksPage:
}
)
onclick = html.escape(onclick)
else: # this path is 'Checkpoints'
vae = item.get("vae", [])
onclick = html.escape(f"selectVAE({vae});") + onclick
btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]})
btn_metadata = ""

View File

@@ -1,42 +1,65 @@
import gradio as gr
from modules import ui_extra_networks_user_metadata, sd_vae, shared
from modules.ui_common import create_refresh_button
from modules.ui_components import ToolButton
from modules_forge import main_entry
refresh_symbol = '\U0001f504' # 🔄
class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
def __init__(self, ui, tabname, page):
super().__init__(ui, tabname, page)
self.select_vae = None
self.sd_version = 'Unknown'
def save_user_metadata(self, name, desc, notes, vae):
def save_user_metadata(self, name, desc, notes, vae, sd_version):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["notes"] = notes
user_metadata["vae"] = vae
user_metadata["sd_version_str"] = 'SdVersion.' + sd_version
self.write_user_metadata(name, user_metadata)
def update_vae(self, name):
if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
sd_vae.reload_vae_weights()
def put_values_into_components(self, name):
user_metadata = self.get_user_metadata(name)
values = super().put_values_into_components(name)
vae = user_metadata.get('vae', None)
version = user_metadata.get('sd_version_str', '')
if version == '':
version = 'Unknown'
else:
version = version.replace('SdVersion.', '')
return [
*values[0:5],
user_metadata.get('vae', ''),
vae,
version,
]
def create_editor(self):
def create_editor(self): #happens before main_entry.modules_list is filled
modules_list = ['Built in']
if main_entry.module_list == {}:
_, modules = main_entry.refresh_models()
modules_list += list(modules)
else:
modules_list += list(main_entry.module_list.keys())
def refreshModules ():
return gr.update(choices=['Built in'] + list(main_entry.module_list.keys()))
self.create_default_editor_elems()
self.sd_version = gr.Radio(['SD1', 'SD2', 'SDXL', 'Flux', 'Unknown'], value='Unknown', label='Base model', interactive=True)
with gr.Row():
self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
self.select_vae = gr.Dropdown(choices=modules_list, value=None, label="Preferred VAE / Text encoder(s)", elem_id="checpoint_edit_user_metadata_preferred_vae", multiselect=True)
self.refresh = ToolButton(refresh_symbol)
self.refresh.click(fn=refreshModules, outputs=self.select_vae, show_progress='hidden')
self.edit_notes = gr.TextArea(label='Notes', lines=4)
@@ -49,6 +72,7 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
self.html_preview,
self.edit_notes,
self.select_vae,
self.sd_version,
]
self.button_edit\
@@ -59,8 +83,7 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
self.edit_description,
self.edit_notes,
self.select_vae,
self.sd_version,
]
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])

View File

@@ -324,15 +324,18 @@ class UiSettings:
show_progress=False,
)
def button_set_checkpoint_change(value, dummy):
return value.split(' [')[0], opts.dumpjson()
def button_set_checkpoint_change(model, vae, dummy):
if 'Built in' in vae:
vae.remove('Built in')
model = sd_models.match_checkpoint_to_name(model)
return model, vae, opts.dumpjson()
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
fn=button_set_checkpoint_change,
js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
inputs=[main_entry.ui_checkpoint, self.dummy_component],
outputs=[main_entry.ui_checkpoint, self.text_settings],
js="function(c, v, n){ var ckpt = desiredCheckpointName; var vae = desiredVAEName; if (vae == 0) vae = v; desiredCheckpointName = null; desiredVAEName = 0; return [ckpt, vae, null]; }",
inputs=[main_entry.ui_checkpoint, main_entry.ui_vae, self.dummy_component],
outputs=[main_entry.ui_checkpoint, main_entry.ui_vae, self.text_settings],
)
component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]