Add ability to train a full rank LoRA. (experimental)

Jaret Burkett
2025-09-09 07:36:25 -06:00
parent 645046701b
commit af6fdaaaf9
3 changed files with 34 additions and 8 deletions
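
In short: when the network type is "fullrank", each wrapped layer's lora_down becomes a full-size projection (in_dim -> out_dim, or a full Conv2d) and lora_up is swapped for a weightless IdentityModule, so the trained delta has the same shape as the original weight instead of a low-rank factorization. The sketch below is illustrative only; dimensions and variable names are made up and it is not the toolkit's API:

import torch

class IdentityModule(torch.nn.Module):
    def forward(self, x):
        return x

in_dim, out_dim, rank = 64, 128, 4

# standard low-rank pair: delta_W = up @ down has rank <= 4
lora_down = torch.nn.Linear(in_dim, rank, bias=False)
lora_up = torch.nn.Linear(rank, out_dim, bias=False)

# "fullrank" variant: the down projection alone carries a full-size delta,
# and the identity just passes activations through unchanged
full_down = torch.nn.Linear(in_dim, out_dim, bias=False)
full_up = IdentityModule()

x = torch.randn(2, in_dim)
assert lora_up(lora_down(x)).shape == full_up(full_down(x)).shape  # both (2, 128)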


@@ -157,6 +157,9 @@ class NetworkConfig:
         elif linear is not None:
             self.rank: int = linear
             self.linear: int = linear
+        else:
+            self.rank: int = 4
+            self.linear: int = 4
         self.conv: int = kwargs.get('conv', None)
         self.alpha: float = kwargs.get('alpha', 1.0)
         self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)


@@ -38,6 +38,10 @@ CONV_MODULES = [
     'QConv2d',
 ]
 
+class IdentityModule(torch.nn.Module):
+    def forward(self, x):
+        return x
+
 class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
     """
     replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -81,16 +85,25 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
         # else:
         self.lora_dim = lora_dim
+        self.full_rank = network.network_type.lower() == "fullrank"
 
         if org_module.__class__.__name__ in CONV_MODULES:
             kernel_size = org_module.kernel_size
             stride = org_module.stride
             padding = org_module.padding
-            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
-            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
+            if self.full_rank:
+                self.lora_down = torch.nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False)
+                self.lora_up = IdentityModule()
+            else:
+                self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+                self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
         else:
-            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
-            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
+            if self.full_rank:
+                self.lora_down = torch.nn.Linear(in_dim, out_dim, bias=False)
+                self.lora_up = IdentityModule()
+            else:
+                self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+                self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
 
         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
@@ -100,7 +113,8 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
 
         # same as microsoft's
         torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
-        torch.nn.init.zeros_(self.lora_up.weight)
+        if not self.full_rank:
+            torch.nn.init.zeros_(self.lora_up.weight)
 
         self.multiplier: Union[float, List[float]] = multiplier
         # wrap the original module so it doesn't get weights updated
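
A minimal reading of the init change above, with made-up shapes: the Kaiming init of lora_down is unchanged, while the zero init is skipped in full-rank mode because the IdentityModule standing in for lora_up has no weight to zero.

import math
import torch

full_rank = True
in_dim, out_dim, rank = 64, 128, 4

lora_down = torch.nn.Linear(in_dim, out_dim if full_rank else rank, bias=False)
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))  # same as before

if not full_rank:
    # only the low-rank path has an up projection whose weight starts at zero
    lora_up = torch.nn.Linear(rank, out_dim, bias=False)
    torch.nn.init.zeros_(lora_up.weight)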
@@ -232,6 +246,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.is_lumina2 = is_lumina2
         self.network_type = network_type
         self.is_assistant_adapter = is_assistant_adapter
+        self.full_rank = network_type.lower() == "fullrank"
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
             module_class = DoRAModule
@@ -426,7 +441,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 except:
                     pass
             else:
-                lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
+                if self.full_rank:
+                    lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)]
+                else:
+                    lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
             return loras, skipped
 
         text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]


@@ -349,7 +349,10 @@ class ToolkitModuleMixin:
         if not self.can_merge_in:
             return
         # get up/down weight
-        up_weight = self.lora_up.weight.clone().float()
+        if self.full_rank:
+            up_weight = None
+        else:
+            up_weight = self.lora_up.weight.clone().float()
         down_weight = self.lora_down.weight.clone().float()
 
         # extract weight from org_module
@@ -374,7 +377,9 @@ class ToolkitModuleMixin:
             scale = scale * self.scalar
 
         # merge weight
-        if len(weight.size()) == 2:
+        if self.full_rank:
+            weight = weight + multiplier * down_weight * scale
+        elif len(weight.size()) == 2:
             # linear
             weight = weight + multiplier * (up_weight @ down_weight) * scale
         elif down_weight.size()[2:4] == (1, 1):
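
For the 2-D (linear) case, the merge arithmetic above reduces to the sketch below (tensors are random stand-ins; multiplier and scale play the same role in both branches): in full-rank mode the delta is the down weight itself, otherwise it is the usual up @ down product.

import torch

out_dim, in_dim, rank = 128, 64, 4
weight = torch.randn(out_dim, in_dim)   # stands in for the org_module weight
multiplier, scale = 1.0, 1.0

full_rank = True
if full_rank:
    down_weight = torch.randn(out_dim, in_dim)   # full-size delta
    weight = weight + multiplier * down_weight * scale
else:
    down_weight = torch.randn(rank, in_dim)
    up_weight = torch.randn(out_dim, rank)
    weight = weight + multiplier * (up_weight @ down_weight) * scale

assert weight.shape == (out_dim, in_dim)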