From 7876862c43e5c8562891a8dd4c57ed4e357f6c02 Mon Sep 17 00:00:00 2001 From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com> Date: Wed, 25 Sep 2024 20:45:11 +0100 Subject: [PATCH] Vae/te preferences via cards (#1912) Allows setting of preferred VAE and Text encoder(s) for checkpoints when selected via Checkpoint cards. No selection saved means no change to current toprow setting. 'Built in' option, if the only choice, means clear the toprow selection (therefore use vae/te built-in to checkpoint). Also allows setting model type for checkpoints (SD1/SD2/SDXL/Flux/Unknown) (user set only, no attempt at autodetection), enabling filtering of the cards based on UI preset. --- javascript/ui.js | 4 ++ modules/sd_models.py | 9 ++++ modules/ui_extra_networks.py | 11 ++++- ...xtra_networks_checkpoints_user_metadata.py | 47 ++++++++++++++----- modules/ui_settings.py | 13 +++-- 5 files changed, 66 insertions(+), 18 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index c617d407..d4b3dbf3 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -362,6 +362,10 @@ function selectCheckpoint(name) { desiredCheckpointName = name; gradioApp().getElementById('change_checkpoint').click(); } +var desiredVAEName = 0; +function selectVAE(vae) { + desiredVAEName = vae; +} function currentImg2imgSourceResolution(w, h, r) { var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] :is(img, canvas)'); diff --git a/modules/sd_models.py b/modules/sd_models.py index c3085785..611ef764 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -185,6 +185,15 @@ def list_models(): re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") +def match_checkpoint_to_name(name): + name = name.split(' [')[0] + + for ckptname in checkpoints_list.values(): + title = ckptname.title.split(' [')[0] + if (name in title) or (title in name): + return ckptname.short_title if shared.opts.sd_checkpoint_dropdown_use_short else ckptname.name.split(' [')[0] + + return name def get_closet_checkpoint_match(search_string): if not search_string: diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index beb43e50..88285bbe 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -214,6 +214,12 @@ class ExtraNetworksPage: desc = metadata.get("description", None) if desc is not None: item["description"] = desc + vae = metadata.get("vae", None) + if vae is not None: + item["vae"] = vae + version = metadata.get("sd_version_str", None) + if version is not None: + item["sd_version_str"] = version item["user_metadata"] = metadata @@ -257,7 +263,7 @@ class ExtraNetworksPage: background_image = f'' if preview else '' onclick = item.get("onclick", None) - if onclick is None: + if onclick is None: # this path is 'Textual Inversion' and 'Lora' # Don't quote prompt/neg_prompt since they are stored as js strings already. onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});" onclick = onclick_js_tpl.format( @@ -269,6 +275,9 @@ class ExtraNetworksPage: } ) onclick = html.escape(onclick) + else: # this path is 'Checkpoints' + vae = item.get("vae", []) + onclick = html.escape(f"selectVAE({vae});") + onclick btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]}) btn_metadata = "" diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py index 25df0a80..b7c66382 100644 --- a/modules/ui_extra_networks_checkpoints_user_metadata.py +++ b/modules/ui_extra_networks_checkpoints_user_metadata.py @@ -1,42 +1,65 @@ import gradio as gr from modules import ui_extra_networks_user_metadata, sd_vae, shared -from modules.ui_common import create_refresh_button +from modules.ui_components import ToolButton +from modules_forge import main_entry +refresh_symbol = '\U0001f504' # 🔄 class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): def __init__(self, ui, tabname, page): super().__init__(ui, tabname, page) self.select_vae = None + self.sd_version = 'Unknown' - def save_user_metadata(self, name, desc, notes, vae): + def save_user_metadata(self, name, desc, notes, vae, sd_version): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["notes"] = notes user_metadata["vae"] = vae + user_metadata["sd_version_str"] = 'SdVersion.' + sd_version self.write_user_metadata(name, user_metadata) - def update_vae(self, name): - if name == shared.sd_model.sd_checkpoint_info.name_for_extra: - sd_vae.reload_vae_weights() - def put_values_into_components(self, name): user_metadata = self.get_user_metadata(name) values = super().put_values_into_components(name) + + vae = user_metadata.get('vae', None) + + version = user_metadata.get('sd_version_str', '') + if version == '': + version = 'Unknown' + else: + version = version.replace('SdVersion.', '') return [ *values[0:5], - user_metadata.get('vae', ''), + vae, + version, ] - def create_editor(self): + def create_editor(self): #happens before main_entry.modules_list is filled + modules_list = ['Built in'] + if main_entry.module_list == {}: + _, modules = main_entry.refresh_models() + modules_list += list(modules) + else: + modules_list += list(main_entry.module_list.keys()) + + def refreshModules (): + return gr.update(choices=['Built in'] + list(main_entry.module_list.keys())) + self.create_default_editor_elems() + self.sd_version = gr.Radio(['SD1', 'SD2', 'SDXL', 'Flux', 'Unknown'], value='Unknown', label='Base model', interactive=True) + with gr.Row(): - self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae") - create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae") + self.select_vae = gr.Dropdown(choices=modules_list, value=None, label="Preferred VAE / Text encoder(s)", elem_id="checpoint_edit_user_metadata_preferred_vae", multiselect=True) + self.refresh = ToolButton(refresh_symbol) + + self.refresh.click(fn=refreshModules, outputs=self.select_vae, show_progress='hidden') self.edit_notes = gr.TextArea(label='Notes', lines=4) @@ -49,6 +72,7 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE self.html_preview, self.edit_notes, self.select_vae, + self.sd_version, ] self.button_edit\ @@ -59,8 +83,7 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE self.edit_description, self.edit_notes, self.select_vae, + self.sd_version, ] self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components) - self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input]) - diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 1c64607e..be61e5be 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -324,15 +324,18 @@ class UiSettings: show_progress=False, ) - def button_set_checkpoint_change(value, dummy): - return value.split(' [')[0], opts.dumpjson() + def button_set_checkpoint_change(model, vae, dummy): + if 'Built in' in vae: + vae.remove('Built in') + model = sd_models.match_checkpoint_to_name(model) + return model, vae, opts.dumpjson() button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) button_set_checkpoint.click( fn=button_set_checkpoint_change, - js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", - inputs=[main_entry.ui_checkpoint, self.dummy_component], - outputs=[main_entry.ui_checkpoint, self.text_settings], + js="function(c, v, n){ var ckpt = desiredCheckpointName; var vae = desiredVAEName; if (vae == 0) vae = v; desiredCheckpointName = null; desiredVAEName = 0; return [ckpt, vae, null]; }", + inputs=[main_entry.ui_checkpoint, main_entry.ui_vae, self.dummy_component], + outputs=[main_entry.ui_checkpoint, main_entry.ui_vae, self.text_settings], ) component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]