Normalize the image embeddings in the VisionDirectAdapter forward pass

This commit is contained in:
Jaret Burkett
2024-10-12 15:09:48 +00:00
parent 628a7923a3
commit ce759ebd8c

View File

@@ -37,7 +37,9 @@ class Norm(nn.Module):
# Normalize
return self.target_std * (x - mean) / (std + self.eps) + self.target_mean
norm_layer = Norm()
class SparseAutoencoder(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
@@ -769,8 +771,12 @@ class VisionDirectAdapter(torch.nn.Module):
def forward(self, input):
# block scaler keeps moving dtypes. make sure it is float32 here
# todo remove this when we have a real solution
if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
self.block_scaler.data = self.block_scaler.data.to(torch.float32)
# if doing image_embeds, normalize here
if self.config.clip_layer == "image_embeds":
input = norm_layer(input)
if self.resampler is not None:
input = self.resampler(input)
if self.pool is not None: