diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 4be85b3a..82e52092 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -345,6 +345,14 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): return embedded +class CLIP_SD_15_L(FrozenCLIPEmbedderWithCustomWords): + pass + + +class CLIP_SD_21_G(FrozenCLIPEmbedderWithCustomWords): + pass + + class CLIP_SD_XL_L(FrozenCLIPEmbedderWithCustomWords): def __init__(self, wrapped, hijack): super().__init__(wrapped, hijack) @@ -358,3 +366,18 @@ class CLIP_SD_XL_L(FrozenCLIPEmbedderWithCustomWords): z = outputs.hidden_states[self.wrapped.layer_idx] return z + + +class CLIP_SD_XL_G(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden") + + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + + return z