mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Normalize the image embeddings on vd adapter forward
This commit is contained in:
@@ -37,7 +37,9 @@ class Norm(nn.Module):
|
||||
|
||||
# Normalize
|
||||
return self.target_std * (x - mean) / (std + self.eps) + self.target_mean
|
||||
|
||||
|
||||
|
||||
norm_layer = Norm()
|
||||
|
||||
class SparseAutoencoder(nn.Module):
|
||||
def __init__(self, input_dim, hidden_dim, output_dim):
|
||||
@@ -769,8 +771,12 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
# block scaler keeps moving dtypes. make sure it is float32 here
|
||||
# todo remove this when we have a real solution
|
||||
|
||||
if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
|
||||
self.block_scaler.data = self.block_scaler.data.to(torch.float32)
|
||||
# if doing image_embeds, normalize here
|
||||
if self.config.clip_layer == "image_embeds":
|
||||
input = norm_layer(input)
|
||||
if self.resampler is not None:
|
||||
input = self.resampler(input)
|
||||
if self.pool is not None:
|
||||
|
||||
Reference in New Issue
Block a user