WIP on SAFE encoder. Improvements to fp16 training, plus various other tweaks.

This commit is contained in:
Jaret Burkett
2024-05-27 10:50:24 -06:00
parent 68b7e159bc
commit 833c833f28
9 changed files with 127 additions and 49 deletions

View File

@@ -21,19 +21,24 @@ class ILoRAProjModule(torch.nn.Module):
self.num_modules = num_modules
self.num_dim = dim
self.norm = torch.nn.LayerNorm(embeddings_dim)
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(embeddings_dim),
torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
torch.nn.GELU(),
torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 2),
torch.nn.LayerNorm(embeddings_dim * 2),
torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 4),
torch.nn.GELU(),
torch.nn.Linear(embeddings_dim * 4, num_modules * dim),
torch.nn.LayerNorm(num_modules * dim),
)
# Initialize the last linear layer weights near zero
torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
torch.nn.init.zeros_(self.proj[2].bias)
torch.nn.init.uniform_(self.proj[-2].weight, a=-0.01, b=0.01)
torch.nn.init.zeros_(self.proj[-2].bias)
def forward(self, x):
x = self.norm(x)
x = self.proj(x)
x = x.reshape(-1, self.num_modules, self.num_dim)
return x
@@ -71,6 +76,8 @@ class InstantLoRAMidModule(torch.nn.Module):
# reshape if needed
if len(x.shape) == 3:
scaler = scaler.unsqueeze(1)
if len(x.shape) == 4:
scaler = scaler.unsqueeze(-1).unsqueeze(-1)
except Exception as e:
print(e)
print(x.shape)

View File

@@ -20,11 +20,11 @@ class SAFEReducerBlock(nn.Module):
self.reducer = nn.Sequential(
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
nn.BatchNorm2d(channels),
activation(),
nn.BatchNorm2d(channels),
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
nn.BatchNorm2d(channels),
activation(),
nn.BatchNorm2d(channels),
nn.AvgPool2d(kernel_size=2, stride=2),
)
self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
@@ -227,6 +227,7 @@ class SAFEVMConfig:
self.reducer_channels = reducer_channels
self.channels = channels
self.downscale_factor = downscale_factor
self.image_size = 224
self.hidden_size = num_vectors
self.projection_dim = num_vectors
@@ -242,7 +243,9 @@ class SAFEVMReturn:
class SAFEVisionModel(SizeAgnosticFeatureEncoder):
def __init__(self, **kwargs):
self.config = SAFEVMConfig(**kwargs)
super().__init__(**kwargs)
self.image_size = None
# super().__init__(**kwargs)
super(SAFEVisionModel, self).__init__(**kwargs)
@classmethod
def from_pretrained(cls, *args, **kwargs):