Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 11:11:37 +00:00).
Add ability to train a full rank LoRA. (experimental)
This commit is contained in:
@@ -157,6 +157,9 @@ class NetworkConfig:
|
|||||||
elif linear is not None:
|
elif linear is not None:
|
||||||
self.rank: int = linear
|
self.rank: int = linear
|
||||||
self.linear: int = linear
|
self.linear: int = linear
|
||||||
|
else:
|
||||||
|
self.rank: int = 4
|
||||||
|
self.linear: int = 4
|
||||||
self.conv: int = kwargs.get('conv', None)
|
self.conv: int = kwargs.get('conv', None)
|
||||||
self.alpha: float = kwargs.get('alpha', 1.0)
|
self.alpha: float = kwargs.get('alpha', 1.0)
|
||||||
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
|
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
|
||||||
|
|||||||
@@ -38,6 +38,10 @@ CONV_MODULES = [
|
|||||||
'QConv2d',
|
'QConv2d',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
class IdentityModule(torch.nn.Module):
    """Pass-through module: ``forward`` hands back its input untouched.

    The class registers no parameters or buffers of its own, so instances
    have no ``weight`` attribute. Code that handles it alongside weighted
    layers must not assume one exists.
    """

    def forward(self, x):
        """Return ``x`` unchanged (the very same object, not a copy)."""
        return x
|
||||||
|
|
||||||
class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||||
@@ -81,16 +85,25 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
|||||||
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||||
# else:
|
# else:
|
||||||
self.lora_dim = lora_dim
|
self.lora_dim = lora_dim
|
||||||
|
self.full_rank = network.network_type.lower() == "fullrank"
|
||||||
|
|
||||||
if org_module.__class__.__name__ in CONV_MODULES:
|
if org_module.__class__.__name__ in CONV_MODULES:
|
||||||
kernel_size = org_module.kernel_size
|
kernel_size = org_module.kernel_size
|
||||||
stride = org_module.stride
|
stride = org_module.stride
|
||||||
padding = org_module.padding
|
padding = org_module.padding
|
||||||
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
if self.full_rank:
|
||||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
|
self.lora_down = torch.nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=False)
|
||||||
|
self.lora_up = IdentityModule()
|
||||||
|
else:
|
||||||
|
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
||||||
|
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=use_bias)
|
||||||
else:
|
else:
|
||||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
if self.full_rank:
|
||||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
|
self.lora_down = torch.nn.Linear(in_dim, out_dim, bias=False)
|
||||||
|
self.lora_up = IdentityModule()
|
||||||
|
else:
|
||||||
|
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||||
|
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=use_bias)
|
||||||
|
|
||||||
if type(alpha) == torch.Tensor:
|
if type(alpha) == torch.Tensor:
|
||||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||||
@@ -100,7 +113,8 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
|||||||
|
|
||||||
# same as microsoft's
|
# same as microsoft's
|
||||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||||
torch.nn.init.zeros_(self.lora_up.weight)
|
if not self.full_rank:
|
||||||
|
torch.nn.init.zeros_(self.lora_up.weight)
|
||||||
|
|
||||||
self.multiplier: Union[float, List[float]] = multiplier
|
self.multiplier: Union[float, List[float]] = multiplier
|
||||||
# wrap the original module so it doesn't get weights updated
|
# wrap the original module so it doesn't get weights updated
|
||||||
@@ -232,6 +246,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
self.is_lumina2 = is_lumina2
|
self.is_lumina2 = is_lumina2
|
||||||
self.network_type = network_type
|
self.network_type = network_type
|
||||||
self.is_assistant_adapter = is_assistant_adapter
|
self.is_assistant_adapter = is_assistant_adapter
|
||||||
|
self.full_rank = network_type.lower() == "fullrank"
|
||||||
if self.network_type.lower() == "dora":
|
if self.network_type.lower() == "dora":
|
||||||
self.module_class = DoRAModule
|
self.module_class = DoRAModule
|
||||||
module_class = DoRAModule
|
module_class = DoRAModule
|
||||||
@@ -426,7 +441,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
|
if self.full_rank:
|
||||||
|
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape)]
|
||||||
|
else:
|
||||||
|
lora_shape_dict[lora_name] = [list(lora.lora_down.weight.shape), list(lora.lora_up.weight.shape)]
|
||||||
return loras, skipped
|
return loras, skipped
|
||||||
|
|
||||||
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
||||||
|
|||||||
@@ -349,7 +349,10 @@ class ToolkitModuleMixin:
|
|||||||
if not self.can_merge_in:
|
if not self.can_merge_in:
|
||||||
return
|
return
|
||||||
# get up/down weight
|
# get up/down weight
|
||||||
up_weight = self.lora_up.weight.clone().float()
|
if self.full_rank:
|
||||||
|
up_weight = None
|
||||||
|
else:
|
||||||
|
up_weight = self.lora_up.weight.clone().float()
|
||||||
down_weight = self.lora_down.weight.clone().float()
|
down_weight = self.lora_down.weight.clone().float()
|
||||||
|
|
||||||
# extract weight from org_module
|
# extract weight from org_module
|
||||||
@@ -374,7 +377,9 @@ class ToolkitModuleMixin:
|
|||||||
scale = scale * self.scalar
|
scale = scale * self.scalar
|
||||||
|
|
||||||
# merge weight
|
# merge weight
|
||||||
if len(weight.size()) == 2:
|
if self.full_rank:
|
||||||
|
weight = weight + multiplier * down_weight * scale
|
||||||
|
elif len(weight.size()) == 2:
|
||||||
# linear
|
# linear
|
||||||
weight = weight + multiplier * (up_weight @ down_weight) * scale
|
weight = weight + multiplier * (up_weight @ down_weight) * scale
|
||||||
elif down_weight.size()[2:4] == (1, 1):
|
elif down_weight.size()[2:4] == (1, 1):
|
||||||
|
|||||||
Reference in New Issue
Block a user