mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Work on embedding adapters
This commit is contained in:
@@ -216,7 +216,10 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
|
||||
adapter_hidden_states = torch.cat([
|
||||
self.unconditional_embeds,
|
||||
adapter_hidden_states
|
||||
])
|
||||
], dim=0)
|
||||
# if it is image embeds, we need to add a 1 dim at inx 1
|
||||
if len(adapter_hidden_states.shape) == 2:
|
||||
adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
|
||||
# conditional_batch_size = adapter_hidden_states.shape[0]
|
||||
# conditional_query = query
|
||||
|
||||
@@ -268,7 +271,10 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||
self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
|
||||
|
||||
self.token_size = vision_model.config.hidden_size
|
||||
if adapter.config.clip_layer == "image_embeds":
|
||||
self.token_size = vision_model.config.projection_dim
|
||||
else:
|
||||
self.token_size = vision_model.config.hidden_size
|
||||
|
||||
# init adapter modules
|
||||
attn_procs = {}
|
||||
|
||||
Reference in New Issue
Block a user