mirror of https://github.com/ostris/ai-toolkit.git
Added support for pixart sigma loras
@@ -54,14 +54,19 @@ class LoRAGenerator(torch.nn.Module):
         super().__init__()
         self.input_size = input_size
         self.num_heads = num_heads
+        self.simple = False
 
         self.output_size = output_size
-        self.lin_in = nn.Linear(input_size, hidden_size)
-
-        self.mlp_blocks = nn.Sequential(*[
-            MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
-        ])
-        self.head = nn.Linear(hidden_size, head_size, bias=False)
+        if self.simple:
+            self.head = nn.Linear(input_size, head_size, bias=False)
+        else:
+            self.lin_in = nn.Linear(input_size, hidden_size)
+
+            self.mlp_blocks = nn.Sequential(*[
+                MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
+            ])
+            self.head = nn.Linear(hidden_size, head_size, bias=False)
+
         self.norm = nn.LayerNorm(head_size)
 
         if num_heads == 1:
@@ -90,8 +95,11 @@ class LoRAGenerator(torch.nn.Module):
         if len(embedding.shape) == 2:
             embedding = embedding.unsqueeze(1)
 
-        x = self.lin_in(embedding)
-        x = self.mlp_blocks(x)
+        x = embedding
+
+        if not self.simple:
+            x = self.lin_in(embedding)
+            x = self.mlp_blocks(x)
         x = self.head(x)
         x = self.norm(x)
 
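
For context, below is a minimal, self-contained sketch of the class as it stands after both hunks. It is an approximation, not the repository's code: the MLP stand-in, the constructor signature, and all default values here are assumptions, and the num_heads handling that follows in the real file is omitted. Only the self.simple branching shown in the diff above is taken directly from the commit.

import torch
import torch.nn as nn


class MLP(nn.Module):
    # Stand-in for the repo's MLP block; the real implementation may differ.
    def __init__(self, in_size, hidden_size, out_size, dropout=0.0, use_residual=True):
        super().__init__()
        self.use_residual = use_residual and in_size == out_size
        self.net = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, out_size),
        )

    def forward(self, x):
        out = self.net(x)
        return x + out if self.use_residual else out


class LoRAGenerator(nn.Module):
    # Constructor signature and defaults are assumptions for illustration.
    def __init__(self, input_size=768, hidden_size=768, head_size=64,
                 output_size=64, num_mlp_layers=2, dropout=0.0, num_heads=1):
        super().__init__()
        self.input_size = input_size
        self.num_heads = num_heads
        self.simple = False  # hardcoded in the commit; flip to bypass the MLP trunk

        self.output_size = output_size
        if self.simple:
            # Simple mode: project the embedding straight to the head size.
            self.head = nn.Linear(input_size, head_size, bias=False)
        else:
            # Full mode: linear-in, residual MLP trunk, then the head.
            self.lin_in = nn.Linear(input_size, hidden_size)
            self.mlp_blocks = nn.Sequential(*[
                MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True)
                for _ in range(num_mlp_layers)
            ])
            self.head = nn.Linear(hidden_size, head_size, bias=False)
        self.norm = nn.LayerNorm(head_size)
        # The num_heads == 1 branch from the real file is omitted here.

    def forward(self, embedding):
        if len(embedding.shape) == 2:
            embedding = embedding.unsqueeze(1)

        x = embedding
        if not self.simple:
            x = self.lin_in(embedding)
            x = self.mlp_blocks(x)
        x = self.head(x)
        return self.norm(x)


if __name__ == "__main__":
    gen = LoRAGenerator()
    emb = torch.randn(2, 768)  # (batch, input_size)
    print(gen(emb).shape)      # torch.Size([2, 1, 64])

Note that the commit hardcodes self.simple = False, so the new pass-through path (head and norm applied directly to the embedding) stays dormant until that flag is changed.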