From 027ab005b6ae2ffe7703f9ff36ab86d80d2ca8fd Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 06:11:24 -0800 Subject: [PATCH] Create forge_clip.py --- modules_forge/forge_clip.py | 39 +++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 modules_forge/forge_clip.py diff --git a/modules_forge/forge_clip.py b/modules_forge/forge_clip.py new file mode 100644 index 00000000..e83209f6 --- /dev/null +++ b/modules_forge/forge_clip.py @@ -0,0 +1,39 @@ +from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords + + +class CLIP_SD_15_L(FrozenCLIPEmbedderWithCustomWords): + pass + + +class CLIP_SD_21_G(FrozenCLIPEmbedderWithCustomWords): + pass + + +class CLIP_SD_XL_L(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden") + + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + + return z + + +class CLIP_SD_XL_G(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden") + + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + + return z