Remove layers from direct vision resampler

Author: Jaret Burkett
Date: 2024-09-24 15:08:29 -06:00
parent 10817696fb
commit 6b4034122f


@@ -27,8 +27,6 @@ class MLPR(nn.Module): # MLP with reshaping
         in_channels,
         out_dim,
         out_channels,
-        hidden_dim,
-        hidden_channels,
         use_residual=True
     ):
         super().__init__()
@@ -37,24 +35,16 @@ class MLPR(nn.Module): # MLP with reshaping
         # dont normalize if using conv
         self.layer_norm = nn.LayerNorm(in_dim)
-        self.fc1 = nn.Linear(in_dim, hidden_dim)
-        self.conv1 = nn.Conv1d(in_channels, hidden_channels, 1)
-        self.fc2 = nn.Linear(hidden_dim, out_dim)
-        self.conv2 = nn.Conv1d(hidden_channels, out_channels, 1)
         self.use_residual = use_residual
+        self.fc1 = nn.Linear(in_dim, out_dim)
         self.act_fn = nn.GELU()
+        self.conv1 = nn.Conv1d(in_channels, out_channels, 1)

     def forward(self, x):
         residual = x
         x = self.layer_norm(x)
         x = self.fc1(x)
-        x = self.conv1(x)
         x = self.act_fn(x)
-        x = self.fc2(x)
-        x = self.conv2(x)
         if self.use_residual:
             x = x + residual
+        x = self.conv1(x)
         return x


 class AttnProcessor2_0(torch.nn.Module):
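Reconstructed from the two hunks above, the post-commit MLPR reduces to LayerNorm -> Linear -> GELU -> optional residual -> 1x1 Conv1d. A minimal runnable sketch follows; the imports and the in_dim parameter (implied by the LayerNorm/Linear context but not visible in the hunks) are filled in here as assumptions:

    # Sketch of MLPR as it reads after this commit; imports and in_dim
    # are reconstructed, not shown in the diff itself.
    import torch
    import torch.nn as nn


    class MLPR(nn.Module):  # MLP with reshaping
        def __init__(
            self,
            in_dim,
            in_channels,
            out_dim,
            out_channels,
            use_residual=True
        ):
            super().__init__()
            # dont normalize if using conv
            self.layer_norm = nn.LayerNorm(in_dim)
            self.use_residual = use_residual
            self.fc1 = nn.Linear(in_dim, out_dim)
            self.act_fn = nn.GELU()
            # 1x1 Conv1d mixes along the channel (sequence) axis,
            # remapping in_channels tokens to out_channels tokens
            self.conv1 = nn.Conv1d(in_channels, out_channels, 1)

        def forward(self, x):
            residual = x
            x = self.layer_norm(x)
            x = self.fc1(x)
            x = self.act_fn(x)
            if self.use_residual:
                # shape-compatible only when out_dim == in_dim
                x = x + residual
            x = self.conv1(x)
            return x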
@@ -666,9 +656,6 @@ class VisionDirectAdapter(torch.nn.Module):
             in_channels=max_seq_len,
             out_dim=self.token_size,
             out_channels=self.config.num_tokens,
-            hidden_dim=self.token_size,
-            hidden_channels=max_seq_len,
-            use_residual=False
         )

     def state_dict(self, destination=None, prefix='', keep_vars=False):
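A hedged usage sketch of the VisionDirectAdapter call site above: in_dim is not visible in the hunk, and because dropping use_residual=False re-enables the default residual add, the sketch assumes in_dim == out_dim == token_size. All concrete sizes are illustrative, not from the diff:

    # Illustrative sizes only: a 257-token vision sequence of width 768
    # resampled down to 32 tokens of the same width.
    token_size, max_seq_len, num_tokens = 768, 257, 32

    resampler = MLPR(
        in_dim=token_size,  # assumption: residual add needs in_dim == out_dim
        in_channels=max_seq_len,
        out_dim=token_size,
        out_channels=num_tokens,
    )

    vision_embeds = torch.randn(2, max_seq_len, token_size)  # (batch, seq, dim)
    resampled = resampler(vision_embeds)
    print(resampled.shape)  # torch.Size([2, 32, 768])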