diff --git a/modules_forge/forge_clip.py b/modules_forge/forge_clip.py new file mode 100644 index 00000000..e83209f6 --- /dev/null +++ b/modules_forge/forge_clip.py @@ -0,0 +1,39 @@ +from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords + + +class CLIP_SD_15_L(FrozenCLIPEmbedderWithCustomWords): + pass + + +class CLIP_SD_21_G(FrozenCLIPEmbedderWithCustomWords): + pass + + +class CLIP_SD_XL_L(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden") + + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + + return z + + +class CLIP_SD_XL_G(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden") + + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + + return z