mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-12 22:19:48 +00:00
Add a projection layer on vision direct when doing image embeds
This commit is contained in:
@@ -753,6 +753,9 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
hidden_dim=hidden_dim,
|
||||
output_dim=self.config.sparse_autoencoder_dim
|
||||
)
|
||||
|
||||
if self.config.clip_layer == "image_embeds":
|
||||
self.proj = nn.Linear(self.token_size, self.token_size)
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
if self.config.train_scaler:
|
||||
@@ -777,6 +780,7 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
# if doing image_embeds, normalize here
|
||||
if self.config.clip_layer == "image_embeds":
|
||||
input = norm_layer(input)
|
||||
input = self.proj(input)
|
||||
if self.resampler is not None:
|
||||
input = self.resampler(input)
|
||||
if self.pool is not None:
|
||||
|
||||
Reference in New Issue
Block a user