Add method to do continuous LoRA merging for low-VRAM full finetuning.

This commit is contained in:
Jaret Burkett
2026-02-26 09:00:41 -07:00
parent de7d22c9be
commit 40f995f616
2 changed files with 30 additions and 3 deletions

View File

@@ -522,7 +522,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# prepare meta
save_meta = get_meta_for_safetensors(save_meta, self.job.name)
if not self.is_fine_tuning:
if not self.is_fine_tuning and not self.train_config.merge_network_on_save:
if self.network is not None:
lora_name = self.job.name
if self.named_lora:
@@ -628,6 +628,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
direct_save=direct_save
)
else:
if self.network is not None and self.train_config.merge_network_on_save:
# merge the network weights into a full model and save that
if not self.network.can_merge_in:
raise ValueError("Network cannot merge in weights. Cannot save full model.")
print_acc("Merging network weights into full model for saving...")
self.network.merge_in(merge_weight=1.0)
# reset weights to zero
self.network.reset_weights()
self.network.is_merged_in = False
print_acc("Done merging network weights.")
if self.save_config.save_format == "diffusers":
# saving as a folder path
file_path = file_path.replace('.safetensors', '')
@@ -1538,7 +1552,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.hook_before_model_load()
model_config_to_load = copy.deepcopy(self.model_config)
if self.is_fine_tuning:
if self.is_fine_tuning or self.train_config.merge_network_on_save:
# get the latest checkpoint
# check to see if we have a latest save
latest_save_path = self.get_latest_save_path()
@@ -1832,7 +1846,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
latest_save_path = self.get_latest_save_path(lora_name)
extra_weights = None
if latest_save_path is not None:
if latest_save_path is not None and not self.train_config.merge_network_on_save:
print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
print_acc(f"Loading from {latest_save_path}")
extra_weights = self.load_weights(latest_save_path)

View File

@@ -404,6 +404,15 @@ class ToolkitModuleMixin:
# set weight to org_module
org_sd[weight_key] = weight.to(weight_device, orig_dtype)
self.org_module[0].load_state_dict(org_sd)
def reset_weights(self: Module):
    """Zero the module's ``lora_up`` tensors in place.

    With ``lora_up == 0`` the adapter's additive contribution is zero,
    while ``lora_down`` keeps its current initialization — so training can
    continue from a "fresh" adapter after a merge.
    """
    zeroed_sd = {
        key: torch.zeros_like(tensor) if 'lora_up' in key else tensor
        for key, tensor in self.state_dict().items()
    }
    self.load_state_dict(zeroed_sd)
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
# LoRM (Low Rank Middle) is a method reduce the number of parameters in a module while keeping the inputs and
@@ -811,6 +820,10 @@ class ToolkitNetworkMixin:
# not supported
self.is_checkpointing = False
self._update_checkpointing()
def reset_weights(self: Network):
    # Delegate to every LoRA module in the network; each one zeroes only
    # its own lora_up weights (see ToolkitModuleMixin.reset_weights).
    for lora_module in self.get_all_modules():
        lora_module.reset_weights()
def merge_in(self, merge_weight=1.0):
if self.network_type.lower() == 'dora':