Added SDXL support for clip vision embedder trainer

This commit is contained in:
Jaret Burkett
2023-12-24 14:31:29 -07:00
parent 05ae95ca89
commit 27ad79053e

View File

@@ -245,8 +245,21 @@ class ClipVisionAdapter(torch.nn.Module):
if len(self.text_encoder_list) == 1:
# add it to the text encoder
self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
elif len(self.text_encoder_list) == 2:
if self.text_encoder_list[0].config.hidden_size + self.text_encoder_list[1].config.hidden_size != \
image_prompt_embeds.shape[2]:
raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes")
# sdxl variants
# image_prompt_embeds = 2048
# te1 = 768
# te2 = 1280
te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.hidden_size]
te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.hidden_size:]
self.set_vec(te1_embeds[0], text_encoder_idx=0)
self.set_vec(te2_embeds[0], text_encoder_idx=1)
else:
raise ValueError("Multiple text encoders not supported yet")
raise ValueError("Unsupported number of text encoders")
# just a place to put a breakpoint
pass