Some work on sd3 training. Not working

This commit is contained in:
Jaret Burkett
2024-06-13 12:19:16 -06:00
parent cb5d28cba9
commit bd10d2d668
12 changed files with 306 additions and 36 deletions

View File

@@ -48,7 +48,7 @@ class LoRAGenerator(torch.nn.Module):
head_size: int = 512,
num_mlp_layers: int = 1,
output_size: int = 768,
dropout: float = 0.5
dropout: float = 0.0
):
super().__init__()
self.input_size = input_size
@@ -131,8 +131,12 @@ class InstantLoRAMidModule(torch.nn.Module):
x_chunk = x_chunks[i]
# reshape
weight_chunk = weight_chunk.view(self.down_shape)
# run a simple lenear layer with the down weight
x_chunk = x_chunk @ weight_chunk.T
# check if is conv or linear
if len(weight_chunk.shape) == 4:
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
else:
# run a simple linear layer with the down weight
x_chunk = x_chunk @ weight_chunk.T
x_out.append(x_chunk)
x = torch.cat(x_out, dim=0)
return x
@@ -158,8 +162,12 @@ class InstantLoRAMidModule(torch.nn.Module):
x_chunk = x_chunks[i]
# reshape
weight_chunk = weight_chunk.view(self.up_shape)
# run a simple lenear layer with the down weight
x_chunk = x_chunk @ weight_chunk.T
# check if is conv or linear
if len(weight_chunk.shape) == 4:
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
else:
# run a simple linear layer with the down weight
x_chunk = x_chunk @ weight_chunk.T
x_out.append(x_chunk)
x = torch.cat(x_out, dim=0)
return x