mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
WIP on SAFE encoder. Work on fp16 training improvements. Various other tweaks and improvements
This commit is contained in:
@@ -21,19 +21,24 @@ class ILoRAProjModule(torch.nn.Module):
|
||||
|
||||
self.num_modules = num_modules
|
||||
self.num_dim = dim
|
||||
self.norm = torch.nn.LayerNorm(embeddings_dim)
|
||||
|
||||
self.proj = torch.nn.Sequential(
|
||||
torch.nn.LayerNorm(embeddings_dim),
|
||||
torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
|
||||
torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 2),
|
||||
torch.nn.LayerNorm(embeddings_dim * 2),
|
||||
|
||||
torch.nn.Linear(embeddings_dim * 2, embeddings_dim * 4),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Linear(embeddings_dim * 4, num_modules * dim),
|
||||
torch.nn.LayerNorm(num_modules * dim),
|
||||
)
|
||||
# Initialize the last linear layer weights near zero
|
||||
torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
|
||||
torch.nn.init.zeros_(self.proj[2].bias)
|
||||
torch.nn.init.uniform_(self.proj[-2].weight, a=-0.01, b=0.01)
|
||||
torch.nn.init.zeros_(self.proj[-2].bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
x = self.proj(x)
|
||||
x = x.reshape(-1, self.num_modules, self.num_dim)
|
||||
return x
|
||||
@@ -71,6 +76,8 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
# reshape if needed
|
||||
if len(x.shape) == 3:
|
||||
scaler = scaler.unsqueeze(1)
|
||||
if len(x.shape) == 4:
|
||||
scaler = scaler.unsqueeze(-1).unsqueeze(-1)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(x.shape)
|
||||
|
||||
@@ -20,11 +20,11 @@ class SAFEReducerBlock(nn.Module):
|
||||
|
||||
self.reducer = nn.Sequential(
|
||||
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(channels),
|
||||
activation(),
|
||||
nn.BatchNorm2d(channels),
|
||||
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(channels),
|
||||
activation(),
|
||||
nn.BatchNorm2d(channels),
|
||||
nn.AvgPool2d(kernel_size=2, stride=2),
|
||||
)
|
||||
self.residual_shrink = nn.AvgPool2d(kernel_size=2, stride=2)
|
||||
@@ -227,6 +227,7 @@ class SAFEVMConfig:
|
||||
self.reducer_channels = reducer_channels
|
||||
self.channels = channels
|
||||
self.downscale_factor = downscale_factor
|
||||
self.image_size = 224
|
||||
|
||||
self.hidden_size = num_vectors
|
||||
self.projection_dim = num_vectors
|
||||
@@ -242,7 +243,9 @@ class SAFEVMReturn:
|
||||
class SAFEVisionModel(SizeAgnosticFeatureEncoder):
|
||||
def __init__(self, **kwargs):
|
||||
self.config = SAFEVMConfig(**kwargs)
|
||||
super().__init__(**kwargs)
|
||||
self.image_size = None
|
||||
# super().__init__(**kwargs)
|
||||
super(SAFEVisionModel, self).__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user