Mirror of https://github.com/ostris/ai-toolkit.git
Add ability to train a full rank LoRA. (experimental)
@@ -157,6 +157,9 @@ class NetworkConfig:
         elif linear is not None:
             self.rank: int = linear
             self.linear: int = linear
+        else:
+            self.rank: int = 4
+            self.linear: int = 4
         self.conv: int = kwargs.get('conv', None)
         self.alpha: float = kwargs.get('alpha', 1.0)
         self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
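As a quick sketch of how these fields resolve, here is a hypothetical helper mirroring the logic above; NetworkConfig's full constructor signature is not shown in this hunk, so the function name and call pattern are assumptions:

# Hypothetical stand-in for the resolution logic in NetworkConfig above.
def resolve_network_dims(linear=None, **kwargs):
    if linear is not None:
        rank = linear
    else:
        rank = 4  # default rank when none is configured
    return {
        'rank': rank,
        'conv': kwargs.get('conv', None),
        'alpha': kwargs.get('alpha', 1.0),
        'linear_alpha': kwargs.get('linear_alpha', kwargs.get('alpha', 1.0)),
    }

print(resolve_network_dims(linear=16, alpha=8.0))
# {'rank': 16, 'conv': None, 'alpha': 8.0, 'linear_alpha': 8.0}
print(resolve_network_dims()['rank'])  # 4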
@@ -38,6 +38,10 @@ CONV_MODULES = [
     'QConv2d',
 ]

+class IdentityModule(torch.nn.Module):
+    def forward(self, x):
+        return x
+
 class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
     """
     replaces forward method of the original Linear, instead of replacing the original Linear module.
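A minimal check of what the new IdentityModule does (the class body is copied from the hunk above; the test values are made up):

import torch

class IdentityModule(torch.nn.Module):  # as added above
    def forward(self, x):
        return x

# The module returns its input untouched, so wiring it in as lora_up turns
# the "up" projection into a no-op: lora_up(lora_down(x)) == lora_down(x).
x = torch.randn(2, 8)
assert torch.equal(IdentityModule()(x), x)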
@@ -81,16 +85,25 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         # print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
         # else:
         self.lora_dim = lora_dim
+        self.full_rank = network.network_type.lower() == "fullrank"

         if org_module.__class__.__name__ in CONV_MODULES:
             kernel_size = org_module.kernel_size
             stride = org_module.stride
             padding = org_module.padding
-            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
-            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
+            if self.full_rank:
+                self.lora_down = torch.nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False)
+                self.lora_up = IdentityModule()
+            else:
+                self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+                self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
         else:
-            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
-            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
+            if self.full_rank:
+                self.lora_down = torch.nn.Linear(in_dim, out_dim, bias=False)
+                self.lora_up = IdentityModule()
+            else:
+                self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+                self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)

         if type(alpha) == torch.Tensor:
             alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
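The effect on parameter shapes, illustrated with made-up dimensions: in full-rank mode the adapter is a single full-size delta matrix and the "up" half contributes nothing, while the low-rank path keeps the usual down/up bottleneck pair.

import torch

in_dim, out_dim, lora_dim = 320, 640, 16

# Low-rank pair: the usual LoRA factorization, (out x r) @ (r x in).
lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
print(lora_down.weight.shape, lora_up.weight.shape)
# torch.Size([16, 320]) torch.Size([640, 16])

# Full-rank: a single (out x in) weight; the "up" half is the identity.
full_rank_down = torch.nn.Linear(in_dim, out_dim, bias=False)
print(full_rank_down.weight.shape)  # torch.Size([640, 320])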
@@ -100,7 +113,8 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):

         # same as microsoft's
         torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
-        torch.nn.init.zeros_(self.lora_up.weight)
+        if not self.full_rank:
+            torch.nn.init.zeros_(self.lora_up.weight)

         self.multiplier: Union[float, List[float]] = multiplier
         # wrap the original module so it doesn't get weights updated
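The guard is needed because IdentityModule owns no parameters, so the unconditional zero-init would fail in full-rank mode. A minimal reproduction (class body copied from above):

import torch

class IdentityModule(torch.nn.Module):
    def forward(self, x):
        return x

# Zeroing lora_up.weight would crash when lora_up is the identity stand-in,
# since it has no .weight parameter.
try:
    torch.nn.init.zeros_(IdentityModule().weight)
except AttributeError as e:
    print(e)  # 'IdentityModule' object has no attribute 'weight'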
@@ -232,6 +246,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.is_lumina2 = is_lumina2
         self.network_type = network_type
         self.is_assistant_adapter = is_assistant_adapter
+        self.full_rank = network_type.lower() == "fullrank"
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
             module_class = DoRAModule
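For orientation, a hedged sketch of the type-string dispatch; only the class names and the "fullrank"/"dora" strings come from this diff, the helper itself is invented:

# Invented helper; in the real network these strings drive both the module
# class choice and the self.full_rank flag set above.
def describe_network_type(network_type: str) -> str:
    if network_type.lower() == "dora":
        return "DoRAModule"
    if network_type.lower() == "fullrank":
        return "LoRAModule with full_rank=True (experimental)"
    return "LoRAModule"

print(describe_network_type("fullrank"))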
@@ -426,7 +441,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             except:
                 pass
             else:
-                lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
+                if self.full_rank:
+                    lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)]
+                else:
+                    lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
         return loras, skipped

     text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
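Illustrative lora_shape_dict entries; the layer name and dimensions below are made up:

# Full-rank module: a single (out_dim, in_dim) matrix is recorded.
full_rank_entry = {'lora_unet_attn_to_q': [[640, 320]]}
# Low-rank module: the usual down/up pair, here for rank 16.
low_rank_entry = {'lora_unet_attn_to_q': [[16, 320], [640, 16]]}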
@@ -349,7 +349,10 @@ class ToolkitModuleMixin:
         if not self.can_merge_in:
             return
         # get up/down weight
-        up_weight = self.lora_up.weight.clone().float()
+        if self.full_rank:
+            up_weight = None
+        else:
+            up_weight = self.lora_up.weight.clone().float()
         down_weight = self.lora_down.weight.clone().float()

         # extract weight from org_module
@@ -374,7 +377,9 @@ class ToolkitModuleMixin:
             scale = scale * self.scalar

         # merge weight
-        if len(weight.size()) == 2:
+        if self.full_rank:
+            weight = weight + multiplier * down_weight * scale
+        elif len(weight.size()) == 2:
             # linear
             weight = weight + multiplier * (up_weight @ down_weight) * scale
         elif down_weight.size()[2:4] == (1, 1):
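A numeric check of the two merge rules above (all shapes and values are made up): the low-rank path composes up @ down into a delta, while the full-rank path adds down_weight directly, since it already matches the original weight's shape and up_weight is None.

import torch

out_dim, in_dim, rank = 6, 4, 2
weight = torch.randn(out_dim, in_dim)  # original module weight
multiplier, scale = 1.0, 1.0

# Low-rank path: delta = up @ down, shape (out_dim, in_dim).
up_weight = torch.randn(out_dim, rank)
down_weight = torch.randn(rank, in_dim)
merged = weight + multiplier * (up_weight @ down_weight) * scale

# Full-rank path: the delta is already (out_dim, in_dim); up_weight is None.
full_delta = torch.randn(out_dim, in_dim)
merged_full = weight + multiplier * full_delta * scale
print(merged.shape, merged_full.shape)  # torch.Size([6, 4]) torch.Size([6, 4])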