Work on embedding adapters

This commit is contained in:
Jaret Burkett
2024-03-11 15:18:42 -06:00
parent f415bac7b5
commit d87b49882c
3 changed files with 55 additions and 6 deletions

View File

@@ -216,7 +216,10 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
adapter_hidden_states = torch.cat([
self.unconditional_embeds,
adapter_hidden_states
])
], dim=0)
# if it is image embeds, we need to add a 1 dim at index 1
if len(adapter_hidden_states.shape) == 2:
adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
# conditional_batch_size = adapter_hidden_states.shape[0]
# conditional_query = query
@@ -268,7 +271,10 @@ class VisionDirectAdapter(torch.nn.Module):
self.sd_ref: weakref.ref = weakref.ref(sd)
self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
self.token_size = vision_model.config.hidden_size
if adapter.config.clip_layer == "image_embeds":
self.token_size = vision_model.config.projection_dim
else:
self.token_size = vision_model.config.hidden_size
# init adapter modules
attn_procs = {}