diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
index 25c5e983..c690b5e7 100644
--- a/modules/sd_hijack_open_clip.py
+++ b/modules/sd_hijack_open_clip.py
@@ -1,71 +1,15 @@
-import open_clip.tokenizer
 import torch
 
-from modules import sd_hijack_clip, devices
+from modules import sd_hijack_clip
 from modules.shared import opts
 
-tokenizer = open_clip.tokenizer._tokenizer
 
-
-class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
     def __init__(self, wrapped, hijack):
         super().__init__(wrapped, hijack)
 
-        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
-        self.id_start = tokenizer.encoder["<start_of_text>"]
-        self.id_end = tokenizer.encoder["<end_of_text>"]
-        self.id_pad = 0
-
-    def tokenize(self, texts):
-        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
-
-        tokenized = [tokenizer.encode(text) for text in texts]
-
-        return tokenized
-
-    def encode_with_transformers(self, tokens):
-        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
-        z = self.wrapped.encode_with_transformer(tokens)
-
-        return z
-
-    def encode_embedding_init_text(self, init_text, nvpt):
-        ids = tokenizer.encode(init_text)
-        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
-        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
-
-        return embedded
-
 
-class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
     def __init__(self, wrapped, hijack):
         super().__init__(wrapped, hijack)
-
-        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
-        self.id_start = tokenizer.encoder["<start_of_text>"]
-        self.id_end = tokenizer.encoder["<end_of_text>"]
-        self.id_pad = 0
-
-    def tokenize(self, texts):
-        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
-
-        tokenized = [tokenizer.encode(text) for text in texts]
-
-        return tokenized
-
-    def encode_with_transformers(self, tokens):
-        d = self.wrapped.encode_with_transformer(tokens)
-        z = d[self.wrapped.layer]
-
-        pooled = d.get("pooled")
-        if pooled is not None:
-            z.pooled = pooled
-
-        return z
-
-    def encode_embedding_init_text(self, init_text, nvpt):
-        ids = tokenizer.encode(init_text)
-        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
-        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
-
-        return embedded
+        a = 0
diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py
index c0fa0e50..cf3fdf17 100644
--- a/modules_forge/forge_loader.py
+++ b/modules_forge/forge_loader.py
@@ -161,17 +161,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
         for i in range(len(conditioner.embedders)):
             embedder = conditioner.embedders[i]
             typename = type(embedder).__name__
-            if typename == 'FrozenOpenCLIPEmbedder':
-                embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
-                embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
-                model_embeddings = embedder.transformer.text_model.embeddings
-                model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
-                    model_embeddings.token_embedding, sd_hijack.model_hijack)
-                embedder = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, sd_hijack.model_hijack)
-
-                conditioner.embedders[i] = embedder
-                text_cond_models.append(embedder)
-            elif typename == 'FrozenCLIPEmbedder':
+            if typename == 'FrozenCLIPEmbedder':  # SDXL Clip L
                 embedder.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer
                 embedder.transformer = forge_object.clip.cond_stage_model.clip_l.transformer
                 model_embeddings = embedder.transformer.text_model.embeddings
@@ -181,7 +171,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
 
                 conditioner.embedders[i] = embedder
                 text_cond_models.append(embedder)
-            elif typename == 'FrozenOpenCLIPEmbedder2':
+            elif typename == 'FrozenOpenCLIPEmbedder2':  # SDXL Clip G
                 embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
                 embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
                 model_embeddings = embedder.transformer.text_model.embeddings
@@ -196,7 +186,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
             sd_model.cond_stage_model = text_cond_models[0]
         else:
             sd_model.cond_stage_model = conditioner
-    elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':
+    elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':  # SD15 Clip
         sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer
         sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_l.transformer
         model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
@@ -204,7 +194,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
         model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
             model_embeddings.token_embedding, sd_hijack.model_hijack)
         sd_model.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(sd_model.cond_stage_model, sd_hijack.model_hijack)
-    elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':
+    elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':  # SD21 Clip
         sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
         sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
         model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
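
Note on the rewiring pattern: every branch that survives in `load_model_for_a1111` performs the same three steps on a text encoder — rebind its tokenizer and transformer to the forge-owned CLIP objects, wrap its token embedding in `sd_hijack.EmbeddingsWithFixes` so textual-inversion vectors are substituted at lookup time, then wrap the encoder itself in a `...WithCustomWords` class so A1111 prompt parsing (emphasis, attention weights) applies. A minimal sketch of that shared pattern follows; `rewire_text_encoder` and the single `forge_clip` parameter are illustrative assumptions for this note, not helpers that exist in the repository:

```python
from modules import sd_hijack, sd_hijack_clip


def rewire_text_encoder(embedder, forge_clip,
                        wrapper_cls=sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
    # Hypothetical helper mirroring the repeated branch bodies above.
    # `forge_clip` is assumed to bundle the matching `.tokenizer` and
    # `.transformer` (clip_l for SD15 / SDXL Clip L, clip_g for SD21 / SDXL Clip G).

    # 1. Point the ldm/sgm embedder at the forge-owned tokenizer and transformer.
    embedder.tokenizer = forge_clip.tokenizer
    embedder.transformer = forge_clip.transformer

    # 2. Wrap the token embedding so textual-inversion embeddings are injected.
    model_embeddings = embedder.transformer.text_model.embeddings
    model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
        model_embeddings.token_embedding, sd_hijack.model_hijack)

    # 3. Wrap the embedder itself so A1111 prompt emphasis/weighting applies.
    return wrapper_cls(embedder, sd_hijack.model_hijack)
```

Under that sketch, the SDXL Clip L branch would reduce to something like `conditioner.embedders[i] = rewire_text_encoder(embedder, clip_l)`, where `clip_l` pairs `forge_object.clip.tokenizer.clip_l.tokenizer` with `forge_object.clip.cond_stage_model.clip_l.transformer`; the diff keeps the steps inline per branch instead, which is why the slimmed-down `sd_hijack_open_clip` classes above can shrink to bare subclasses of the `sd_hijack_clip` implementations.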