mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Remove layers from direct vision resampler
This commit is contained in:
@@ -27,8 +27,6 @@ class MLPR(nn.Module): # MLP with reshaping
|
||||
in_channels,
|
||||
out_dim,
|
||||
out_channels,
|
||||
hidden_dim,
|
||||
hidden_channels,
|
||||
use_residual=True
|
||||
):
|
||||
super().__init__()
|
||||
@@ -37,24 +35,16 @@ class MLPR(nn.Module): # MLP with reshaping
|
||||
# dont normalize if using conv
|
||||
self.layer_norm = nn.LayerNorm(in_dim)
|
||||
|
||||
self.fc1 = nn.Linear(in_dim, hidden_dim)
|
||||
self.conv1 = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
||||
self.conv2 = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
|
||||
self.use_residual = use_residual
|
||||
self.fc1 = nn.Linear(in_dim, out_dim)
|
||||
self.act_fn = nn.GELU()
|
||||
self.conv1 = nn.Conv1d(in_channels, out_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
x = self.layer_norm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.conv1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
x = self.conv2(x)
|
||||
if self.use_residual:
|
||||
x = x + residual
|
||||
x = self.conv1(x)
|
||||
return x
|
||||
|
||||
class AttnProcessor2_0(torch.nn.Module):
|
||||
@@ -666,9 +656,6 @@ class VisionDirectAdapter(torch.nn.Module):
|
||||
in_channels=max_seq_len,
|
||||
out_dim=self.token_size,
|
||||
out_channels=self.config.num_tokens,
|
||||
hidden_dim=self.token_size,
|
||||
hidden_channels=max_seq_len,
|
||||
use_residual=False
|
||||
)
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||
|
||||
Reference in New Issue
Block a user