diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 912c241f..2526d86e 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -157,6 +157,9 @@ class NetworkConfig:
         elif linear is not None:
             self.rank: int = linear
             self.linear: int = linear
+        else:
+            self.rank: int = 4
+            self.linear: int = 4
         self.conv: int = kwargs.get('conv', None)
         self.alpha: float = kwargs.get('alpha', 1.0)
         self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index cd443735..2c4c11a8 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -38,6 +38,10 @@ CONV_MODULES = [
     'QConv2d',
 ]
 
+class IdentityModule(torch.nn.Module):
+    def forward(self, x):
+        return x
+
 class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
     """
     replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -81,16 +85,25 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         #     print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
         # else:
         self.lora_dim = lora_dim
+        self.full_rank = network.network_type.lower() == "fullrank"
 
         if org_module.__class__.__name__ in CONV_MODULES:
             kernel_size = org_module.kernel_size
             stride = org_module.stride
             padding = org_module.padding
-            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
-            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
+            if self.full_rank:
+                self.lora_down = torch.nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False)
+                self.lora_up = IdentityModule()
+            else:
+                self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+                self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
         else:
-            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
-            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
+            if self.full_rank:
+                self.lora_down = torch.nn.Linear(in_dim, out_dim, bias=False)
+                self.lora_up = IdentityModule()
+            else:
+                self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+                self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
 
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
@@ -100,7 +113,8 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
 
         # same as microsoft's
         torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
-        torch.nn.init.zeros_(self.lora_up.weight)
+        if not self.full_rank:
+            torch.nn.init.zeros_(self.lora_up.weight)
 
         self.multiplier: Union[float, List[float]] = multiplier
         # wrap the original module so it doesn't get weights updated
@@ -232,6 +246,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.is_lumina2 = is_lumina2
         self.network_type = network_type
         self.is_assistant_adapter = is_assistant_adapter
+        self.full_rank = network_type.lower() == "fullrank"
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
             module_class = DoRAModule
@@ -426,7 +441,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                     except:
                         pass
                 else:
-                    lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
+                    if self.full_rank:
+                        lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)]
+                    else:
+                        lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
             return loras, skipped
 
         text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index a25b6255..617f4baa 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -349,7 +349,10 @@ class ToolkitModuleMixin:
         if not self.can_merge_in:
             return
         # get up/down weight
-        up_weight = self.lora_up.weight.clone().float()
+        if self.full_rank:
+            up_weight = None
+        else:
+            up_weight = self.lora_up.weight.clone().float()
         down_weight = self.lora_down.weight.clone().float()
 
         # extract weight from org_module
@@ -374,7 +377,9 @@ class ToolkitModuleMixin:
             scale = scale * self.scalar
 
         # merge weight
-        if len(weight.size()) == 2:
+        if self.full_rank:
+            weight = weight + multiplier * down_weight * scale
+        elif len(weight.size()) == 2:
             # linear
             weight = weight + multiplier * (up_weight @ down_weight) * scale
         elif down_weight.size()[2:4] == (1, 1):
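
For context on what the "fullrank" branch above does, here is a minimal standalone sketch; it is not the toolkit's own classes, and the names FullRankAdapterSketch, delta, and merge_into are hypothetical, used only for illustration. Assuming a plain torch.nn.Linear base layer, it mirrors the patch's behaviour: lora_down covers the full in_dim -> out_dim mapping, lora_up degenerates to an identity, and merging adds multiplier * down_weight * scale directly instead of multiplier * (up_weight @ down_weight) * scale.

# Minimal sketch of the full-rank adapter path (illustrative names, linear layers only).
import torch


class IdentityModule(torch.nn.Module):
    def forward(self, x):
        return x


class FullRankAdapterSketch(torch.nn.Module):
    def __init__(self, in_dim: int, out_dim: int, multiplier: float = 1.0, scale: float = 1.0):
        super().__init__()
        # Full-rank: the "down" projection already maps in_dim -> out_dim,
        # so the "up" projection is just an identity with no weights.
        self.lora_down = torch.nn.Linear(in_dim, out_dim, bias=False)
        self.lora_up = IdentityModule()
        self.multiplier = multiplier
        self.scale = scale

    def delta(self) -> torch.Tensor:
        # Low-rank LoRA would compute multiplier * (up_weight @ down_weight) * scale;
        # full-rank collapses to the single down weight matrix.
        return self.multiplier * self.lora_down.weight * self.scale

    def merge_into(self, org_linear: torch.nn.Linear) -> None:
        # Add the delta straight onto the original weight, as in the merge path above.
        with torch.no_grad():
            org_linear.weight += self.delta().to(org_linear.weight.dtype)


if __name__ == "__main__":
    base = torch.nn.Linear(8, 16, bias=False)
    adapter = FullRankAdapterSketch(8, 16)
    before = base.weight.clone()
    adapter.merge_into(base)
    assert torch.allclose(base.weight, before + adapter.delta())
    print("merged full-rank delta of shape", tuple(adapter.lora_down.weight.shape))

Note that, as in the patch (which skips torch.nn.init.zeros_ on lora_up in full-rank mode while still applying kaiming init to lora_down), a full-rank adapter contributes a non-zero delta from the start rather than beginning as a zero perturbation the way a standard low-rank LoRA does.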