diff --git a/javascript/ui.js b/javascript/ui.js index c617d407..d4b3dbf3 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -362,6 +362,10 @@ function selectCheckpoint(name) { desiredCheckpointName = name; gradioApp().getElementById('change_checkpoint').click(); } +var desiredVAEName = 0; +function selectVAE(vae) { + desiredVAEName = vae; +} function currentImg2imgSourceResolution(w, h, r) { var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] :is(img, canvas)'); diff --git a/modules/sd_models.py b/modules/sd_models.py index c3085785..611ef764 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -185,6 +185,15 @@ def list_models(): re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") +def match_checkpoint_to_name(name): + name = name.split(' [')[0] + + for ckptname in checkpoints_list.values(): + title = ckptname.title.split(' [')[0] + if (name in title) or (title in name): + return ckptname.short_title if shared.opts.sd_checkpoint_dropdown_use_short else ckptname.name.split(' [')[0] + + return name def get_closet_checkpoint_match(search_string): if not search_string: diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index beb43e50..88285bbe 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -214,6 +214,12 @@ class ExtraNetworksPage: desc = metadata.get("description", None) if desc is not None: item["description"] = desc + vae = metadata.get("vae", None) + if vae is not None: + item["vae"] = vae + version = metadata.get("sd_version_str", None) + if version is not None: + item["sd_version_str"] = version item["user_metadata"] = metadata @@ -257,7 +263,7 @@ class ExtraNetworksPage: background_image = f'' if preview else '' onclick = item.get("onclick", None) - if onclick is None: + if onclick is None: # this path is 'Textual Inversion' and 'Lora' # Don't quote prompt/neg_prompt since they are stored as js strings already. onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});" onclick = onclick_js_tpl.format( @@ -269,6 +275,9 @@ class ExtraNetworksPage: } ) onclick = html.escape(onclick) + else: # this path is 'Checkpoints' + vae = item.get("vae", []) + onclick = html.escape(f"selectVAE({vae});") + onclick btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]}) btn_metadata = "" diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py index 25df0a80..b7c66382 100644 --- a/modules/ui_extra_networks_checkpoints_user_metadata.py +++ b/modules/ui_extra_networks_checkpoints_user_metadata.py @@ -1,42 +1,65 @@ import gradio as gr from modules import ui_extra_networks_user_metadata, sd_vae, shared -from modules.ui_common import create_refresh_button +from modules.ui_components import ToolButton +from modules_forge import main_entry +refresh_symbol = '\U0001f504' # 🔄 class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): def __init__(self, ui, tabname, page): super().__init__(ui, tabname, page) self.select_vae = None + self.sd_version = 'Unknown' - def save_user_metadata(self, name, desc, notes, vae): + def save_user_metadata(self, name, desc, notes, vae, sd_version): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["notes"] = notes user_metadata["vae"] = vae + user_metadata["sd_version_str"] = 'SdVersion.' + sd_version self.write_user_metadata(name, user_metadata) - def update_vae(self, name): - if name == shared.sd_model.sd_checkpoint_info.name_for_extra: - sd_vae.reload_vae_weights() - def put_values_into_components(self, name): user_metadata = self.get_user_metadata(name) values = super().put_values_into_components(name) + + vae = user_metadata.get('vae', None) + + version = user_metadata.get('sd_version_str', '') + if version == '': + version = 'Unknown' + else: + version = version.replace('SdVersion.', '') return [ *values[0:5], - user_metadata.get('vae', ''), + vae, + version, ] - def create_editor(self): + def create_editor(self): #happens before main_entry.modules_list is filled + modules_list = ['Built in'] + if main_entry.module_list == {}: + _, modules = main_entry.refresh_models() + modules_list += list(modules) + else: + modules_list += list(main_entry.module_list.keys()) + + def refreshModules (): + return gr.update(choices=['Built in'] + list(main_entry.module_list.keys())) + self.create_default_editor_elems() + self.sd_version = gr.Radio(['SD1', 'SD2', 'SDXL', 'Flux', 'Unknown'], value='Unknown', label='Base model', interactive=True) + with gr.Row(): - self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae") - create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae") + self.select_vae = gr.Dropdown(choices=modules_list, value=None, label="Preferred VAE / Text encoder(s)", elem_id="checpoint_edit_user_metadata_preferred_vae", multiselect=True) + self.refresh = ToolButton(refresh_symbol) + + self.refresh.click(fn=refreshModules, outputs=self.select_vae, show_progress='hidden') self.edit_notes = gr.TextArea(label='Notes', lines=4) @@ -49,6 +72,7 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE self.html_preview, self.edit_notes, self.select_vae, + self.sd_version, ] self.button_edit\ @@ -59,8 +83,7 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE self.edit_description, self.edit_notes, self.select_vae, + self.sd_version, ] self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components) - self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input]) - diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 1c64607e..be61e5be 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -324,15 +324,18 @@ class UiSettings: show_progress=False, ) - def button_set_checkpoint_change(value, dummy): - return value.split(' [')[0], opts.dumpjson() + def button_set_checkpoint_change(model, vae, dummy): + if 'Built in' in vae: + vae.remove('Built in') + model = sd_models.match_checkpoint_to_name(model) + return model, vae, opts.dumpjson() button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) button_set_checkpoint.click( fn=button_set_checkpoint_change, - js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", - inputs=[main_entry.ui_checkpoint, self.dummy_component], - outputs=[main_entry.ui_checkpoint, self.text_settings], + js="function(c, v, n){ var ckpt = desiredCheckpointName; var vae = desiredVAEName; if (vae == 0) vae = v; desiredCheckpointName = null; desiredVAEName = 0; return [ckpt, vae, null]; }", + inputs=[main_entry.ui_checkpoint, main_entry.ui_vae, self.dummy_component], + outputs=[main_entry.ui_checkpoint, main_entry.ui_vae, self.text_settings], ) component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]