From ce759ebd8c653a5ac61c15c1bdacb210aa37df9e Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 12 Oct 2024 15:09:48 +0000 Subject: [PATCH] Normalize the image embeddings on vd adapter forward --- toolkit/models/vd_adapter.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py index 4232bf86..c40f149b 100644 --- a/toolkit/models/vd_adapter.py +++ b/toolkit/models/vd_adapter.py @@ -37,7 +37,9 @@ class Norm(nn.Module): # Normalize return self.target_std * (x - mean) / (std + self.eps) + self.target_mean - + + +norm_layer = Norm() class SparseAutoencoder(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): @@ -769,8 +771,12 @@ class VisionDirectAdapter(torch.nn.Module): def forward(self, input): # block scaler keeps moving dtypes. make sure it is float32 here # todo remove this when we have a real solution + if self.block_scaler is not None and self.block_scaler.dtype != torch.float32: self.block_scaler.data = self.block_scaler.data.to(torch.float32) + # if doing image_embeds, normalize here + if self.config.clip_layer == "image_embeds": + input = norm_layer(input) if self.resampler is not None: input = self.resampler(input) if self.pool is not None: