diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py index 925b59e6..134f2071 100644 --- a/toolkit/models/ilora.py +++ b/toolkit/models/ilora.py @@ -46,12 +46,14 @@ class LoRAGenerator(torch.nn.Module): input_size: int = 768, # projection dimension hidden_size: int = 768, head_size: int = 512, + num_heads: int = 1, num_mlp_layers: int = 1, output_size: int = 768, dropout: float = 0.0 ): super().__init__() self.input_size = input_size + self.num_heads = num_heads self.output_size = output_size self.lin_in = nn.Linear(input_size, hidden_size) @@ -62,11 +64,18 @@ class LoRAGenerator(torch.nn.Module): self.head = nn.Linear(hidden_size, head_size, bias=False) self.norm = nn.LayerNorm(head_size) - self.flatten = nn.Flatten() - self.output = nn.Linear(head_size, self.output_size) - # for each output block. multiply weights by 0.01 - with torch.no_grad(): - self.output.weight.data *= 0.01 + if num_heads == 1: + self.output = nn.Linear(head_size, self.output_size) + # for each output block. multiply weights by 0.01 + with torch.no_grad(): + self.output.weight.data *= 0.01 + else: + head_output_size = output_size // num_heads + self.outputs = nn.ModuleList([nn.Linear(head_size, head_output_size) for _ in range(num_heads)]) + # for each output block. multiply weights by 0.01 + with torch.no_grad(): + for output in self.outputs: + output.weight.data *= 0.01 # allow get device @property @@ -86,9 +95,15 @@ class LoRAGenerator(torch.nn.Module): x = self.head(x) x = self.norm(x) - head_output = x + if self.num_heads == 1: + x = self.output(x) + else: + out_chunks = torch.chunk(x, self.num_heads, dim=1) + x = [] + for out_layer, chunk in zip(self.outputs, out_chunks): + x.append(out_layer(chunk)) + x = torch.cat(x, dim=-1) - x = self.output(head_output) return x.squeeze(1) @@ -133,7 +148,10 @@ class InstantLoRAMidModule(torch.nn.Module): weight_chunk = weight_chunk.view(self.down_shape) # check if is conv or linear if len(weight_chunk.shape) == 4: - x_chunk = nn.functional.conv2d(x_chunk, weight_chunk) + padding = 0 + if weight_chunk.shape[-1] == 3: + padding = 1 + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding) else: # run a simple linear layer with the down weight x_chunk = x_chunk @ weight_chunk.T @@ -164,7 +182,10 @@ class InstantLoRAMidModule(torch.nn.Module): weight_chunk = weight_chunk.view(self.up_shape) # check if is conv or linear if len(weight_chunk.shape) == 4: - x_chunk = nn.functional.conv2d(x_chunk, weight_chunk) + padding = 0 + if weight_chunk.shape[-1] == 3: + padding = 1 + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding) else: # run a simple linear layer with the down weight x_chunk = x_chunk @ weight_chunk.T @@ -239,6 +260,12 @@ class InstantLoRAModule(torch.nn.Module): self.output_size = output_size + # if not evenly divisible, error + if self.output_size % self.num_heads != 0: + raise ValueError("Output size must be divisible by the number of heads") + + self.head_output_size = self.output_size // self.num_heads + if vision_tokens > 1: self.resampler = Resampler( dim=vision_hidden_size,