diff --git a/backend/text_processing/classic_engine.py b/backend/text_processing/classic_engine.py index 10e9079f..ecbdc213 100644 --- a/backend/text_processing/classic_engine.py +++ b/backend/text_processing/classic_engine.py @@ -49,14 +49,17 @@ class CLIPEmbeddingForTextualInversion(torch.nn.Module): class ClassicTextProcessingEngine(torch.nn.Module): def __init__(self, text_encoder, tokenizer, chunk_length=75, - embedding_dir='./embeddings', embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original", + embedding_dir=None, embedding_key='clip_l', embedding_expected_shape=768, emphasis_name="original", text_projection=False, minimal_clip_skip=1, clip_skip=1, return_pooled=False, final_layer_norm=True, callback_before_encode=None): super().__init__() self.embeddings = EmbeddingDatabase(tokenizer, embedding_expected_shape) - self.embeddings.add_embedding_dir(embedding_dir) - self.embeddings.load_textual_inversion_embeddings() + + if isinstance(embedding_dir, str): + self.embeddings.add_embedding_dir(embedding_dir) + self.embeddings.load_textual_inversion_embeddings() + self.embedding_key = embedding_key self.text_encoder = text_encoder