Added support for PixArt Sigma LoRAs

Jaret Burkett
2024-06-16 11:56:30 -06:00
parent ada722c9e4
commit 5d47244c57
5 changed files with 35 additions and 562 deletions

@@ -54,14 +54,19 @@ class LoRAGenerator(torch.nn.Module):
         super().__init__()
         self.input_size = input_size
         self.num_heads = num_heads
+        self.simple = False
         self.output_size = output_size
-        self.lin_in = nn.Linear(input_size, hidden_size)
-        self.mlp_blocks = nn.Sequential(*[
-            MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
-        ])
-        self.head = nn.Linear(hidden_size, head_size, bias=False)
+        if self.simple:
+            self.head = nn.Linear(input_size, head_size, bias=False)
+        else:
+            self.lin_in = nn.Linear(input_size, hidden_size)
+            self.mlp_blocks = nn.Sequential(*[
+                MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
+            ])
+            self.head = nn.Linear(hidden_size, head_size, bias=False)
         self.norm = nn.LayerNorm(head_size)
         if num_heads == 1:
@@ -90,8 +95,11 @@ class LoRAGenerator(torch.nn.Module):
         if len(embedding.shape) == 2:
             embedding = embedding.unsqueeze(1)
-        x = self.lin_in(embedding)
-        x = self.mlp_blocks(x)
+        x = embedding
+        if not self.simple:
+            x = self.lin_in(embedding)
+            x = self.mlp_blocks(x)
         x = self.head(x)
         x = self.norm(x)
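For context, below is a minimal, standalone sketch of what the new `simple` switch changes. It is not the repository's full LoRAGenerator: the constructor signature, the stand-in MLP block, and the sizes in the usage example are assumptions made for illustration. When `simple` is enabled, the conditioning embedding goes straight through a single linear head and LayerNorm, skipping `lin_in` and the MLP stack; when it is disabled, the original full path is used.

import torch
import torch.nn as nn


class ResidualMLP(nn.Module):
    # Hypothetical stand-in for the repo's MLP(hidden, hidden, hidden, use_residual=True).
    def __init__(self, dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )

    def forward(self, x):
        return x + self.net(x)  # residual connection


class LoRAGeneratorSketch(nn.Module):
    # Sketch only: the real class takes more arguments (e.g. num_heads, output_size).
    def __init__(self, input_size, hidden_size, head_size,
                 num_mlp_layers=2, dropout=0.0, simple=False):
        super().__init__()
        self.simple = simple
        if self.simple:
            # Simple path: project the raw embedding directly to the head size.
            self.head = nn.Linear(input_size, head_size, bias=False)
        else:
            # Full path: input projection, MLP stack, then the head.
            self.lin_in = nn.Linear(input_size, hidden_size)
            self.mlp_blocks = nn.Sequential(*[
                ResidualMLP(hidden_size, dropout) for _ in range(num_mlp_layers)
            ])
            self.head = nn.Linear(hidden_size, head_size, bias=False)
        self.norm = nn.LayerNorm(head_size)

    def forward(self, embedding):
        if embedding.dim() == 2:
            embedding = embedding.unsqueeze(1)  # (batch, 1, input_size)
        x = embedding
        if not self.simple:
            x = self.lin_in(x)
            x = self.mlp_blocks(x)
        x = self.head(x)
        return self.norm(x)


# Example (sizes are arbitrary): the simple path maps (2, 768) -> (2, 1, 64).
gen = LoRAGeneratorSketch(input_size=768, hidden_size=512, head_size=64, simple=True)
out = gen(torch.randn(2, 768))

The point of the change is that the normalization and head still apply in both modes, so downstream code consuming the generated weights does not need to know which path produced them.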