From e3ebd7361012ab33dd68961df4761aac8e2692ee Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 20 Oct 2024 10:48:23 -0600 Subject: [PATCH] Add a projection layer on vision direct when doing image embeds --- toolkit/models/vd_adapter.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py index c40f149b..946a1e10 100644 --- a/toolkit/models/vd_adapter.py +++ b/toolkit/models/vd_adapter.py @@ -753,6 +753,9 @@ class VisionDirectAdapter(torch.nn.Module): hidden_dim=hidden_dim, output_dim=self.config.sparse_autoencoder_dim ) + + if self.config.clip_layer == "image_embeds": + self.proj = nn.Linear(self.token_size, self.token_size) def state_dict(self, destination=None, prefix='', keep_vars=False): if self.config.train_scaler: @@ -777,6 +780,7 @@ class VisionDirectAdapter(torch.nn.Module): # if doing image_embeds, normalize here if self.config.clip_layer == "image_embeds": input = norm_layer(input) + input = self.proj(input) if self.resampler is not None: input = self.resampler(input) if self.pool is not None: