diff --git a/toolkit/clip_vision_adapter.py b/toolkit/clip_vision_adapter.py
index 93ef4be2..4691bb0d 100644
--- a/toolkit/clip_vision_adapter.py
+++ b/toolkit/clip_vision_adapter.py
@@ -245,8 +245,21 @@ class ClipVisionAdapter(torch.nn.Module):
         if len(self.text_encoder_list) == 1:
             # add it to the text encoder
             self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
+        elif len(self.text_encoder_list) == 2:
+            if self.text_encoder_list[0].config.hidden_size + self.text_encoder_list[1].config.hidden_size != \
+                    image_prompt_embeds.shape[2]:
+                raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes")
+            # sdxl variants
+            # image_prompt_embeds = 2048
+            # te1 = 768
+            # te2 = 1280
+            te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.hidden_size]
+            te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.hidden_size:]
+            self.set_vec(te1_embeds[0], text_encoder_idx=0)
+            self.set_vec(te2_embeds[0], text_encoder_idx=1)
         else:
+
-            raise ValueError("Multiple text encoders not supported yet")
+            raise ValueError("Unsupported number of text encoders")
         # just a place to put a breakpoint
         pass
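
Not part of the patch: a minimal standalone sketch of the slicing logic added above, assuming the SDXL hidden sizes noted in the inline comments (te1 = 768, te2 = 1280, concatenated width 2048). The tensor shapes and variable names here are illustrative only.

    # Sketch: split a concatenated image-prompt embedding across two text encoders.
    import torch

    te1_hidden = 768   # assumed hidden size of the first text encoder
    te2_hidden = 1280  # assumed hidden size of the second text encoder
    num_tokens = 4     # arbitrary token count for the dummy tensor

    # Stand-in for image_prompt_embeds: (batch, tokens, 768 + 1280 = 2048)
    image_prompt_embeds = torch.randn(1, num_tokens, te1_hidden + te2_hidden)

    # Same slicing as the new elif branch: first 768 channels go to te1, the rest to te2.
    te1_embeds = image_prompt_embeds[:, :, :te1_hidden]
    te2_embeds = image_prompt_embeds[:, :, te1_hidden:]

    assert te1_embeds.shape == (1, num_tokens, 768)
    assert te2_embeds.shape == (1, num_tokens, 1280)
    # Recombining the two slices reproduces the original tensor exactly.
    assert torch.equal(torch.cat([te1_embeds, te2_embeds], dim=-1), image_prompt_embeds)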