diff --git a/modules_forge/forge_clip.py b/modules_forge/forge_clip.py
index a12f4d20..3dc8584a 100644
--- a/modules_forge/forge_clip.py
+++ b/modules_forge/forge_clip.py
@@ -13,6 +13,10 @@ def move_clip_to_gpu():
     return
 
 
+def apply_clip_skip_to_transformer_outputs(x, last_layer, skip):
+    return x.hidden_states[last_layer + 1 - skip]
+
+
 class CLIP_SD_15_L(FrozenCLIPEmbedderWithCustomWords):
     def encode_with_transformers(self, tokens):
         move_clip_to_gpu()
@@ -20,7 +24,7 @@ class CLIP_SD_15_L(FrozenCLIPEmbedderWithCustomWords):
         outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
 
         if opts.CLIP_stop_at_last_layers > 1:
-            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
+            z = apply_clip_skip_to_transformer_outputs(outputs, last_layer=-1, skip=opts.CLIP_stop_at_last_layers)
             z = self.wrapped.transformer.text_model.final_layer_norm(z)
         else:
             z = outputs.last_hidden_state
@@ -45,7 +49,10 @@ class CLIP_SD_21_H(FrozenCLIPEmbedderWithCustomWords):
         self.wrapped.transformer.text_model.embeddings.to(tokens.device)
         outputs = self.wrapped.transformer(tokens, output_hidden_states=self.wrapped.layer == "hidden")
 
-        if self.wrapped.layer == "last":
+        if opts.CLIP_stop_at_last_layers > 1:
+            z = apply_clip_skip_to_transformer_outputs(outputs, last_layer=self.wrapped.layer_idx, skip=opts.CLIP_stop_at_last_layers)
+            z = self.wrapped.transformer.text_model.final_layer_norm(z)
+        elif self.wrapped.layer == "last":
             z = outputs.last_hidden_state
         else:
             z = outputs.hidden_states[self.wrapped.layer_idx]
@@ -62,7 +69,9 @@ class CLIP_SD_XL_L(FrozenCLIPEmbedderWithCustomWords):
         self.wrapped.transformer.text_model.embeddings.to(tokens.device)
         outputs = self.wrapped.transformer(tokens, output_hidden_states=self.wrapped.layer == "hidden")
 
-        if self.wrapped.layer == "last":
+        if opts.CLIP_stop_at_last_layers > 1:
+            z = apply_clip_skip_to_transformer_outputs(outputs, last_layer=self.wrapped.layer_idx, skip=opts.CLIP_stop_at_last_layers)
+        elif self.wrapped.layer == "last":
             z = outputs.last_hidden_state
         else:
             z = outputs.hidden_states[self.wrapped.layer_idx]
@@ -86,7 +95,9 @@ class CLIP_SD_XL_G(FrozenCLIPEmbedderWithCustomWords):
         self.wrapped.transformer.text_model.embeddings.to(tokens.device)
         outputs = self.wrapped.transformer(tokens, output_hidden_states=self.wrapped.layer == "hidden")
 
-        if self.wrapped.layer == "last":
+        if opts.CLIP_stop_at_last_layers > 1:
+            z = apply_clip_skip_to_transformer_outputs(outputs, last_layer=self.wrapped.layer_idx, skip=opts.CLIP_stop_at_last_layers)
+        elif self.wrapped.layer == "last":
             z = outputs.last_hidden_state
         else:
             z = outputs.hidden_states[self.wrapped.layer_idx]