mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-26 09:19:20 +00:00
Some work on sd3 training. Not working
This commit is contained in:
@@ -48,7 +48,7 @@ class LoRAGenerator(torch.nn.Module):
|
||||
head_size: int = 512,
|
||||
num_mlp_layers: int = 1,
|
||||
output_size: int = 768,
|
||||
dropout: float = 0.5
|
||||
dropout: float = 0.0
|
||||
):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
@@ -131,8 +131,12 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
x_chunk = x_chunks[i]
|
||||
# reshape
|
||||
weight_chunk = weight_chunk.view(self.down_shape)
|
||||
# run a simple lenear layer with the down weight
|
||||
x_chunk = x_chunk @ weight_chunk.T
|
||||
# check if is conv or linear
|
||||
if len(weight_chunk.shape) == 4:
|
||||
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
|
||||
else:
|
||||
# run a simple linear layer with the down weight
|
||||
x_chunk = x_chunk @ weight_chunk.T
|
||||
x_out.append(x_chunk)
|
||||
x = torch.cat(x_out, dim=0)
|
||||
return x
|
||||
@@ -158,8 +162,12 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
x_chunk = x_chunks[i]
|
||||
# reshape
|
||||
weight_chunk = weight_chunk.view(self.up_shape)
|
||||
# run a simple lenear layer with the down weight
|
||||
x_chunk = x_chunk @ weight_chunk.T
|
||||
# check if is conv or linear
|
||||
if len(weight_chunk.shape) == 4:
|
||||
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
|
||||
else:
|
||||
# run a simple linear layer with the down weight
|
||||
x_chunk = x_chunk @ weight_chunk.T
|
||||
x_out.append(x_chunk)
|
||||
x = torch.cat(x_out, dim=0)
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user