Add a projection layer on vision direct when doing image embeds

This commit is contained in:
Jaret Burkett
2024-10-20 10:48:23 -06:00
parent dd931757cd
commit e3ebd73610

View File

@@ -753,6 +753,9 @@ class VisionDirectAdapter(torch.nn.Module):
hidden_dim=hidden_dim,
output_dim=self.config.sparse_autoencoder_dim
)
if self.config.clip_layer == "image_embeds":
self.proj = nn.Linear(self.token_size, self.token_size)
def state_dict(self, destination=None, prefix='', keep_vars=False):
if self.config.train_scaler:
@@ -777,6 +780,7 @@ class VisionDirectAdapter(torch.nn.Module):
# if doing image_embeds, normalize here
if self.config.clip_layer == "image_embeds":
input = norm_layer(input)
input = self.proj(input)
if self.resampler is not None:
input = self.resampler(input)
if self.pool is not None: