mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-20 04:13:57 +00:00
Added SDXL support for clip vision embedder trainer
This commit is contained in:
@@ -245,8 +245,21 @@ class ClipVisionAdapter(torch.nn.Module):
|
||||
if len(self.text_encoder_list) == 1:
|
||||
# add it to the text encoder
|
||||
self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
|
||||
elif len(self.text_encoder_list) == 2:
|
||||
if self.text_encoder_list[0].config.hidden_size + self.text_encoder_list[1].config.hidden_size != \
|
||||
image_prompt_embeds.shape[2]:
|
||||
raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes")
|
||||
# sdxl variants
|
||||
# image_prompt_embeds = 2048
|
||||
# te1 = 768
|
||||
# te2 = 1280
|
||||
te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.hidden_size]
|
||||
te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.hidden_size:]
|
||||
self.set_vec(te1_embeds[0], text_encoder_idx=0)
|
||||
self.set_vec(te2_embeds[0], text_encoder_idx=1)
|
||||
else:
|
||||
raise ValueError("Multiple text encoders not supported yet")
|
||||
|
||||
raise ValueError("Unsupported number of text encoders")
|
||||
# just a place to put a breakpoint
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user