mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-04 01:59:48 +00:00
Add method to do continuous LoRA merging, for low-VRAM full finetuning.
This commit is contained in:
@@ -522,7 +522,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# prepare meta
|
||||
save_meta = get_meta_for_safetensors(save_meta, self.job.name)
|
||||
if not self.is_fine_tuning:
|
||||
if not self.is_fine_tuning and not self.train_config.merge_network_on_save:
|
||||
if self.network is not None:
|
||||
lora_name = self.job.name
|
||||
if self.named_lora:
|
||||
@@ -628,6 +628,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
direct_save=direct_save
|
||||
)
|
||||
else:
|
||||
if self.network is not None and self.train_config.merge_network_on_save:
|
||||
# merge the network weights into a full model and save that
|
||||
if not self.network.can_merge_in:
|
||||
raise ValueError("Network cannot merge in weights. Cannot save full model.")
|
||||
|
||||
print_acc("Merging network weights into full model for saving...")
|
||||
|
||||
self.network.merge_in(merge_weight=1.0)
|
||||
# reset weights to zero
|
||||
self.network.reset_weights()
|
||||
self.network.is_merged_in = False
|
||||
|
||||
print_acc("Done merging network weights.")
|
||||
|
||||
if self.save_config.save_format == "diffusers":
|
||||
# saving as a folder path
|
||||
file_path = file_path.replace('.safetensors', '')
|
||||
@@ -1538,7 +1552,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.hook_before_model_load()
|
||||
model_config_to_load = copy.deepcopy(self.model_config)
|
||||
|
||||
if self.is_fine_tuning:
|
||||
if self.is_fine_tuning or self.train_config.merge_network_on_save:
|
||||
# get the latest checkpoint
|
||||
# check to see if we have a latest save
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
@@ -1832,7 +1846,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
latest_save_path = self.get_latest_save_path(lora_name)
|
||||
extra_weights = None
|
||||
if latest_save_path is not None:
|
||||
if latest_save_path is not None and not self.train_config.merge_network_on_save:
|
||||
print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
print_acc(f"Loading from {latest_save_path}")
|
||||
extra_weights = self.load_weights(latest_save_path)
|
||||
|
||||
@@ -404,6 +404,15 @@ class ToolkitModuleMixin:
|
||||
# set weight to org_module
|
||||
org_sd[weight_key] = weight.to(weight_device, orig_dtype)
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
def reset_weights(self: Module):
    """Zero out this module's LoRA "up" weights so its delta becomes zero.

    Only tensors whose state-dict key contains ``lora_up`` are reset;
    ``lora_down`` keeps its current values. Because the LoRA delta is the
    product up @ down, zeroing the up projection alone makes the module's
    contribution exactly zero while leaving the down projection initialized,
    so training can resume from a fresh (zero) delta after a merge.
    """
    org_sd = self.state_dict()
    # Iterate the dict directly (idiomatic; .keys() is redundant). Assigning
    # to existing keys while iterating is safe — the dict size never changes.
    for key in org_sd:
        # only reset lora_up
        if 'lora_up' in key:
            org_sd[key] = torch.zeros_like(org_sd[key])
    self.load_state_dict(org_sd)
|
||||
|
||||
def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
|
||||
# LoRM (Low Rank Middle) is a method reduce the number of parameters in a module while keeping the inputs and
|
||||
@@ -811,6 +820,10 @@ class ToolkitNetworkMixin:
|
||||
# not supported
|
||||
self.is_checkpointing = False
|
||||
self._update_checkpointing()
|
||||
|
||||
def reset_weights(self: "Network"):
    """Reset the LoRA weights of every module in this network to a zero delta.

    Delegates to each module's own ``reset_weights`` implementation.
    """
    all_modules = self.get_all_modules()
    for lora_module in all_modules:
        lora_module.reset_weights()
|
||||
|
||||
def merge_in(self, merge_weight=1.0):
|
||||
if self.network_type.lower() == 'dora':
|
||||
|
||||
Reference in New Issue
Block a user