diff --git a/toolkit/models/ilora2.py b/toolkit/models/ilora2.py
index 5ddc19d6..c46bd0a6 100644
--- a/toolkit/models/ilora2.py
+++ b/toolkit/models/ilora2.py
@@ -9,8 +9,9 @@
 from toolkit.models.clip_fusion import ZipperBlock
 from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
 import sys
 from toolkit.paths import REPOS_ROOT
+
 sys.path.append(REPOS_ROOT)
-from ipadapter.ip_adapter.resampler import Resampler
+from ipadapter.ip_adapter.resampler import Resampler
 from collections import OrderedDict
 if TYPE_CHECKING:
@@ -41,6 +42,7 @@ class MLP(nn.Module):
         x = x + residual
         return x
 
+
 class LoRAGenerator(torch.nn.Module):
     def __init__(
             self,
@@ -65,7 +67,8 @@ class LoRAGenerator(torch.nn.Module):
 
         self.lin_in = nn.Linear(input_size, hidden_size)
         self.mlp_blocks = nn.Sequential(*[
-            MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers)
+            MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in
+            range(num_mlp_layers)
         ])
         self.head = nn.Linear(hidden_size, head_size, bias=False)
         self.norm = nn.LayerNorm(head_size)
@@ -131,11 +134,11 @@ class InstantLoRAMidModule(torch.nn.Module):
         self.index = index
         self.lora_module_ref = weakref.ref(lora_module)
         self.instant_lora_module_ref = weakref.ref(instant_lora_module)
-
+
         self.do_up = instant_lora_module.config.ilora_up
         self.do_down = instant_lora_module.config.ilora_down
         self.do_mid = instant_lora_module.config.ilora_mid
-
+
         self.down_dim = self.down_shape[1] if self.do_down else 0
         self.mid_dim = self.up_shape[1] if self.do_mid else 0
         self.out_dim = self.up_shape[0] if self.do_up else 0
@@ -177,67 +180,74 @@ class InstantLoRAMidModule(torch.nn.Module):
         return x
 
-
     def up_forward(self, x, *args, **kwargs):
-        if not self.do_up and not self.do_mid:
+        # do mid here
+        x = self.mid_forward(x, *args, **kwargs)
+        if not self.do_up:
             return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
 
         # get the embed
         self.embed = self.instant_lora_module_ref().img_embeds[self.index]
-        if self.do_mid:
-            mid_weight = self.embed[:, self.down_dim:self.down_dim+self.mid_dim]
-        else:
-            mid_weight = None
-        if self.do_up:
-            up_weight = self.embed[:, -self.out_dim:]
-        else:
-            up_weight = None
+        up_weight = self.embed[:, -self.out_dim:]
 
         batch_size = x.shape[0]
 
         # unconditional
-        if up_weight is not None:
-            if up_weight.shape[0] * 2 == batch_size:
-                up_weight = torch.cat([up_weight] * 2, dim=0)
-        if mid_weight is not None:
-            if mid_weight.shape[0] * 2 == batch_size:
-                mid_weight = torch.cat([mid_weight] * 2, dim=0)
+        if up_weight.shape[0] * 2 == batch_size:
+            up_weight = torch.cat([up_weight] * 2, dim=0)
 
         try:
             if len(x.shape) == 4:
                 # conv
-                if up_weight is not None:
-                    up_weight = up_weight.view(batch_size, -1, 1, 1)
-                if mid_weight is not None:
-                    mid_weight = mid_weight.view(batch_size, -1, 1, 1)
-                    if x.shape[1] != mid_weight.shape[1]:
-                        raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
+                up_weight = up_weight.view(batch_size, -1, 1, 1)
             elif len(x.shape) == 2:
-                if up_weight is not None:
-                    up_weight = up_weight.view(batch_size, -1)
-                if mid_weight is not None:
-                    mid_weight = mid_weight.view(batch_size, -1)
-                    if x.shape[1] != mid_weight.shape[1]:
-                        raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
+                up_weight = up_weight.view(batch_size, -1)
             else:
-                if up_weight is not None:
-                    up_weight = up_weight.view(batch_size, 1, -1)
-                if mid_weight is not None:
-                    mid_weight = mid_weight.view(batch_size, 1, -1)
-                    if x.shape[2] != mid_weight.shape[2]:
-                        raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
-            # apply mid weight first
-            if mid_weight is not None:
-                x = x * mid_weight
+                up_weight = up_weight.view(batch_size, 1, -1)
 
             x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
-            if up_weight is not None:
-                x = x * up_weight
+            x = x * up_weight
         except Exception as e:
             print(e)
            raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
 
         return x
 
+    def mid_forward(self, x, *args, **kwargs):
+        if not self.do_mid:
+            return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
+        batch_size = x.shape[0]
+        # get the embed
+        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+        mid_weight = self.embed[:, self.down_dim:self.down_dim + self.mid_dim * self.mid_dim]
+        # unconditional
+        if mid_weight.shape[0] * 2 == batch_size:
+            mid_weight = torch.cat([mid_weight] * 2, dim=0)
+
+        weight_chunks = torch.chunk(mid_weight, batch_size, dim=0)
+        x_chunks = torch.chunk(x, batch_size, dim=0)
+
+        x_out = []
+        for i in range(batch_size):
+            weight_chunk = weight_chunks[i]
+            x_chunk = x_chunks[i]
+            # reshape
+            if len(x_chunk.shape) == 4:
+                # conv
+                weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim, 1, 1)
+            else:
+                weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim)
+            # check if is conv or linear
+            if len(weight_chunk.shape) == 4:
+                padding = 0
+                if weight_chunk.shape[-1] == 3:
+                    padding = 1
+                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
+            else:
+                # run a simple linear layer with the down weight
+                x_chunk = x_chunk @ weight_chunk.T
+            x_out.append(x_chunk)
+        x = torch.cat(x_out, dim=0)
+        return x
 
 
 class InstantLoRAModule(torch.nn.Module):
@@ -246,7 +256,7 @@ class InstantLoRAModule(torch.nn.Module):
             vision_hidden_size: int,
             vision_tokens: int,
             head_dim: int,
-            num_heads: int, # number of heads in the resampler
+            num_heads: int,  # number of heads in the resampler
             sd: 'StableDiffusion',
             config: AdapterConfig
     ):
@@ -258,7 +268,7 @@ class InstantLoRAModule(torch.nn.Module):
         self.vision_tokens = vision_tokens
         self.head_dim = head_dim
         self.num_heads = num_heads
-
+
         self.config: AdapterConfig = config
 
         # stores the projection vector. Grabbed by modules
@@ -291,11 +301,10 @@ class InstantLoRAModule(torch.nn.Module):
 
             # just doing in dim and out dim
            in_dim = down_shape[1] if self.config.ilora_down else 0
-            mid_dim = down_shape[0] if self.config.ilora_mid else 0
+            mid_dim = down_shape[0] * down_shape[0] if self.config.ilora_mid else 0
            out_dim = up_shape[0] if self.config.ilora_up else 0
 
            module_size = in_dim + mid_dim + out_dim
-
            output_size += module_size
 
            self.embed_lengths.append(module_size)
@@ -317,7 +326,6 @@ class InstantLoRAModule(torch.nn.Module):
 
             lora_module.lora_up.orig_forward = lora_module.lora_up.forward
             lora_module.lora_up.forward = instant_module.up_forward
-
         self.output_size = output_size
 
         number_formatted_output_size = "{:,}".format(output_size)
@@ -377,7 +385,6 @@ class InstantLoRAModule(torch.nn.Module):
         #     print("No keymap found. Using default names")
         #     return
 
-
     def forward(self, img_embeds):
         # expand token rank if only rank 2
         if len(img_embeds.shape) == 2:
@@ -394,10 +401,9 @@ class InstantLoRAModule(torch.nn.Module):
         # get all the slices
         start = 0
         for length in self.embed_lengths:
-            self.img_embeds.append(img_embeds[:, start:start+length])
+            self.img_embeds.append(img_embeds[:, start:start + length])
             start += length
-
 
     def get_additional_save_metadata(self) -> Dict[str, Any]:
         # save the weight mapping
         return {
@@ -411,4 +417,3 @@ class InstantLoRAModule(torch.nn.Module):
             "do_mid": self.config.ilora_mid,
             "do_down": self.config.ilora_down,
         }
-