Add method to do continuous LoRA merging for low-VRAM full finetuning.

Jaret Burkett
2026-02-26 09:00:41 -07:00
parent de7d22c9be
commit 40f995f616
2 changed files with 30 additions and 3 deletions


@@ -404,6 +404,15 @@ class ToolkitModuleMixin:
# set weight to org_module
org_sd[weight_key] = weight.to(weight_device, orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def reset_weights(self: Module):
# zero out the lora_up weights so the module's effective LoRA delta becomes zero
org_sd = self.state_dict()
for key in org_sd.keys():
# only reset lora_up; lora_down is left untouched
if 'lora_up' in key:
org_sd[key] = torch.zeros_like(org_sd[key])
self.load_state_dict(org_sd)
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
# LoRM (Low Rank Middle) is a method to reduce the number of parameters in a module while keeping the inputs and
@@ -811,6 +820,10 @@ class ToolkitNetworkMixin:
# not supported
self.is_checkpointing = False
self._update_checkpointing()
def reset_weights(self: Network):
for module in self.get_all_modules():
module.reset_weights()
def merge_in(self, merge_weight=1.0):
if self.network_type.lower() == 'dora':
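
Taken together, the two additions sketch the continuous-merging workflow named in the commit message: train the LoRA as usual, periodically fold its delta into the frozen base weights with merge_in(), then call reset_weights() so the lora_up matrices return to zero and the next interval starts from a clean delta, never double-counting what was already merged. Below is a minimal, hypothetical sketch of such a loop; merge_in() and reset_weights() come from this diff, while network, model, dataloader, optimizer and merge_every are illustrative placeholders rather than part of the toolkit's API.

# hypothetical training loop illustrating continuous LoRA merging
merge_every = 100  # merge the accumulated delta into the base weights every N steps

for step, batch in enumerate(dataloader):
    loss = model(batch)      # forward pass with the LoRA modules active
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if (step + 1) % merge_every == 0:
        network.merge_in(merge_weight=1.0)  # fold the current LoRA delta into the base weights
        network.reset_weights()             # zero lora_up so the next merge starts from a zero delta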