From 7c612a012badbcaad2e7f84e18d4aa58793a5da7 Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Sat, 3 Aug 2024 21:21:41 -0700 Subject: [PATCH] make Textual Inversion UI isolated so that we can replace it soon also removed the extremely annoying SD version filter --- modules/textual_inversion/textual_inversion.py | 6 ++++-- modules/ui_extra_networks_textual_inversion.py | 17 ++++++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index dc7833e9..4aa14fe4 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -209,7 +209,7 @@ class EmbeddingDatabase: errors.report(f"Error loading embedding {fn}", exc_info=True) continue - def load_textual_inversion_embeddings(self, force_reload=False): + def load_textual_inversion_embeddings(self, force_reload=False, sync_with_sd_model=True): if not force_reload: need_reload = False for embdir in self.embedding_dirs.values(): @@ -223,7 +223,9 @@ class EmbeddingDatabase: self.ids_lookup.clear() self.word_embeddings.clear() self.skipped_embeddings.clear() - self.expected_shape = self.get_expected_shape() + + if sync_with_sd_model: + self.expected_shape = self.get_expected_shape() for embdir in self.embedding_dirs.values(): self.load_from_dir(embdir) diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index deb7cb87..0887eaf0 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,19 +1,26 @@ import os +import modules.textual_inversion.textual_inversion -from modules import ui_extra_networks, sd_hijack, shared +from modules.shared import cmd_opts +from modules import ui_extra_networks, shared from modules.ui_extra_networks import quote_js +embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() +embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) +embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) + + class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def __init__(self): super().__init__('Textual Inversion') self.allow_negative_prompt = True def refresh(self): - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + embedding_db.load_textual_inversion_embeddings(force_reload=True, sync_with_sd_model=False) def create_item(self, name, index=None, enable_filter=True): - embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) + embedding = embedding_db.word_embeddings.get(name) if embedding is None: return @@ -35,11 +42,11 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def list_items(self): # instantiate a list to protect against concurrent modification - names = list(sd_hijack.model_hijack.embedding_db.word_embeddings) + names = list(embedding_db.word_embeddings) for index, name in enumerate(names): item = self.create_item(name, index) if item is not None: yield item def allowed_directories_for_previews(self): - return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) + return list(embedding_db.embedding_dirs)