Work on embedding adapters

This commit is contained in:
Jaret Burkett
2024-03-11 15:18:42 -06:00
parent f415bac7b5
commit d87b49882c
3 changed files with 55 additions and 6 deletions

View File

@@ -767,6 +767,10 @@ class CustomAdapter(torch.nn.Module):
clip_image_embeds = clip_output.hidden_states[-1]
else:
clip_image_embeds = clip_output.image_embeds
# TODO should we always norm image embeds?
# get norm embeddings
l2_norm = torch.norm(clip_image_embeds, p=2)
clip_image_embeds = clip_image_embeds / l2_norm
if not is_training or not self.config.train_image_encoder:
clip_image_embeds = clip_image_embeds.detach()