diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 38a4014b..ce416093 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -194,6 +194,9 @@ class AdapterConfig:
         # for ilora
         self.head_dim: int = kwargs.get('head_dim', 1024)
         self.num_heads: int = kwargs.get('num_heads', 1)
+        self.ilora_down: bool = kwargs.get('ilora_down', True)
+        self.ilora_mid: bool = kwargs.get('ilora_mid', True)
+        self.ilora_up: bool = kwargs.get('ilora_up', True)
 
 
 class EmbeddingConfig:
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 4ac72b6e..d4ce39a6 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -148,7 +148,8 @@ class CustomAdapter(torch.nn.Module):
                 vision_hidden_size=vision_hidden_size,
                 head_dim=self.config.head_dim,
                 num_heads=self.config.num_heads,
-                sd=self.sd_ref()
+                sd=self.sd_ref(),
+                config=self.config
             )
         elif self.adapter_type == 'text_encoder':
             if self.config.text_encoder_arch == 't5':
diff --git a/toolkit/models/ilora2.py b/toolkit/models/ilora2.py
index 5b905666..de11bea9 100644
--- a/toolkit/models/ilora2.py
+++ b/toolkit/models/ilora2.py
@@ -1,6 +1,7 @@
 import math
 import weakref
 
+from toolkit.config_modules import AdapterConfig
 import torch
 import torch.nn as nn
 from typing import TYPE_CHECKING, List, Dict, Any
@@ -130,14 +131,23 @@ class InstantLoRAMidModule(torch.nn.Module):
         self.index = index
         self.lora_module_ref = weakref.ref(lora_module)
         self.instant_lora_module_ref = weakref.ref(instant_lora_module)
+
+        self.do_up = instant_lora_module.config.ilora_up
+        self.do_down = instant_lora_module.config.ilora_down
+        self.do_mid = instant_lora_module.config.ilora_mid
+
+        self.down_dim = self.down_shape[1] if self.do_down else 0
+        self.mid_dim = self.up_shape[1] if self.do_mid else 0
+        self.out_dim = self.up_shape[0] if self.do_down else 0
 
         self.embed = None
 
     def down_forward(self, x, *args, **kwargs):
+        if not self.do_down:
+            return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs)
         # get the embed
         self.embed = self.instant_lora_module_ref().img_embeds[self.index]
-        in_dim = self.down_shape[1]
-        down_weight = self.embed[:, :in_dim]
+        down_weight = self.embed[:, :self.down_dim]
 
         batch_size = x.shape[0]
 
@@ -169,41 +179,58 @@ class InstantLoRAMidModule(torch.nn.Module):
 
     def up_forward(self, x, *args, **kwargs):
+        if not self.do_up and not self.do_mid:
+            return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
         # get the embed
         self.embed = self.instant_lora_module_ref().img_embeds[self.index]
-        in_dim = self.down_shape[1]
-        mid_dim = self.up_shape[1]
-        out_dim = self.up_shape[0]
-        mid_weight = self.embed[:, in_dim:in_dim+mid_dim]
-        up_weight = self.embed[:, -out_dim:]
+        if self.do_mid:
+            mid_weight = self.embed[:, self.down_dim:self.down_dim+self.mid_dim]
+        else:
+            mid_weight = None
+        if self.do_up:
+            up_weight = self.embed[:, -self.out_dim:]
+        else:
+            up_weight = None
 
         batch_size = x.shape[0]
 
         # unconditional
-        if up_weight.shape[0] * 2 == batch_size:
-            up_weight = torch.cat([up_weight] * 2, dim=0)
-            mid_weight = torch.cat([mid_weight] * 2, dim=0)
+        if up_weight is not None:
+            if up_weight.shape[0] * 2 == batch_size:
+                up_weight = torch.cat([up_weight] * 2, dim=0)
+        if mid_weight is not None:
+            if mid_weight.shape[0] * 2 == batch_size:
+                mid_weight = torch.cat([mid_weight] * 2, dim=0)
 
         try:
             if len(x.shape) == 4:
                 # conv
-                up_weight = up_weight.view(batch_size, -1, 1, 1)
-                mid_weight = mid_weight.view(batch_size, -1, 1, 1)
+                if up_weight is not None:
+                    up_weight = up_weight.view(batch_size, -1, 1, 1)
+                if mid_weight is not None:
+                    mid_weight = mid_weight.view(batch_size, -1, 1, 1)
                 if x.shape[1] != mid_weight.shape[1]:
                     raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
             elif len(x.shape) == 2:
-                up_weight = up_weight.view(batch_size, -1)
-                mid_weight = mid_weight.view(batch_size, -1)
+                if up_weight is not None:
+                    up_weight = up_weight.view(batch_size, -1)
+                if mid_weight is not None:
+                    mid_weight = mid_weight.view(batch_size, -1)
                 if x.shape[1] != mid_weight.shape[1]:
                     raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
             else:
-                up_weight = up_weight.view(batch_size, 1, -1)
-                mid_weight = mid_weight.view(batch_size, 1, -1)
+                if up_weight is not None:
+                    up_weight = up_weight.view(batch_size, 1, -1)
+                if mid_weight is not None:
+                    mid_weight = mid_weight.view(batch_size, 1, -1)
                 if x.shape[2] != mid_weight.shape[2]:
                     raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
 
             # apply mid weight first
+            if mid_weight is not None:
+                x = x * mid_weight
             x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs)
-            x = x * up_weight
+            if up_weight is not None:
+                x = x * up_weight
         except Exception as e:
             print(e)
             raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}")
@@ -220,7 +247,8 @@ class InstantLoRAModule(torch.nn.Module):
             vision_tokens: int,
             head_dim: int,
             num_heads: int,  # number of heads in the resampler
-            sd: 'StableDiffusion'
+            sd: 'StableDiffusion',
+            config: AdapterConfig
     ):
         super(InstantLoRAModule, self).__init__()
         # self.linear = torch.nn.Linear(2, 1)
@@ -230,6 +258,8 @@ class InstantLoRAModule(torch.nn.Module):
         self.vision_tokens = vision_tokens
         self.head_dim = head_dim
         self.num_heads = num_heads
+
+        self.config: AdapterConfig = config
 
         # stores the projection vector. Grabbed by modules
         self.img_embeds: List[torch.Tensor] = None
@@ -260,9 +290,9 @@ class InstantLoRAModule(torch.nn.Module):
 
             # linear weight shape is (out_features, in_features)
             # just doing in dim and out dim
-            in_dim = down_shape[1]
-            mid_dim = down_shape[0]
-            out_dim = up_shape[0]
+            in_dim = down_shape[1] if self.config.ilora_down else 0
+            mid_dim = down_shape[0] if self.config.ilora_mid else 0
+            out_dim = up_shape[0] if self.config.ilora_up else 0
 
             module_size = in_dim + mid_dim + out_dim
 
@@ -377,5 +407,8 @@ class InstantLoRAModule(torch.nn.Module):
             "head_dim": self.head_dim,
             "vision_tokens": self.vision_tokens,
             "output_size": self.output_size,
+            "do_up": self.config.ilora_up,
+            "do_mid": self.config.ilora_mid,
+            "do_down": self.config.ilora_down,
         }
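
Usage note (not part of the diff): a minimal sketch of how the new flags might be set, assuming AdapterConfig accepts them as plain keyword arguments, as its kwargs.get(...) calls suggest. All three flags default to True, so existing configs keep the previous behavior; disabling a flag zeroes that component's share of the per-module embedding (module_size = in_dim + mid_dim + out_dim).

# Hypothetical sketch, not taken from the diff: exercising the new ilora_* flags.
from toolkit.config_modules import AdapterConfig

config = AdapterConfig(
    head_dim=1024,    # existing ilora settings (defaults shown)
    num_heads=1,
    ilora_down=True,  # keep the projected down weights (in_dim = down_shape[1])
    ilora_mid=True,   # keep the scale applied before lora_up (mid_dim = down_shape[0])
    ilora_up=False,   # drop the scale applied after lora_up (out_dim = 0)
)
print(config.ilora_down, config.ilora_mid, config.ilora_up)  # True True False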