diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py index c40f149b..946a1e10 100644 --- a/toolkit/models/vd_adapter.py +++ b/toolkit/models/vd_adapter.py @@ -753,6 +753,9 @@ class VisionDirectAdapter(torch.nn.Module): hidden_dim=hidden_dim, output_dim=self.config.sparse_autoencoder_dim ) + + if self.config.clip_layer == "image_embeds": + self.proj = nn.Linear(self.token_size, self.token_size) def state_dict(self, destination=None, prefix='', keep_vars=False): if self.config.train_scaler: @@ -777,6 +780,7 @@ class VisionDirectAdapter(torch.nn.Module): # if doing image_embeds, normalize here if self.config.clip_layer == "image_embeds": input = norm_layer(input) + input = self.proj(input) if self.resampler is not None: input = self.resampler(input) if self.pool is not None: