From 27ad79053eda724f5f57cfbf893e87a1bcf44951 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 24 Dec 2023 14:31:29 -0700 Subject: [PATCH] Added SDXL support for clip vision embedder trainer --- toolkit/clip_vision_adapter.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/toolkit/clip_vision_adapter.py b/toolkit/clip_vision_adapter.py index 93ef4be2..4691bb0d 100644 --- a/toolkit/clip_vision_adapter.py +++ b/toolkit/clip_vision_adapter.py @@ -245,8 +245,21 @@ class ClipVisionAdapter(torch.nn.Module): if len(self.text_encoder_list) == 1: # add it to the text encoder self.set_vec(image_prompt_embeds[0], text_encoder_idx=0) + elif len(self.text_encoder_list) == 2: + if self.text_encoder_list[0].config.hidden_size + self.text_encoder_list[1].config.hidden_size != \ + image_prompt_embeds.shape[2]: + raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes") + # sdxl variants + # image_prompt_embeds = 2048 + # te1 = 768 + # te2 = 1280 + te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.hidden_size] + te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.hidden_size:] + self.set_vec(te1_embeds[0], text_encoder_idx=0) + self.set_vec(te2_embeds[0], text_encoder_idx=1) else: - raise ValueError("Multiple text encoders not supported yet") + + raise ValueError("Unsupported number of text encoders") # just a place to put a breakpoint pass