diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 925d34da..e55525aa 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -522,7 +522,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # prepare meta
         save_meta = get_meta_for_safetensors(save_meta, self.job.name)
 
-        if not self.is_fine_tuning:
+        if not self.is_fine_tuning and not self.train_config.merge_network_on_save:
             if self.network is not None:
                 lora_name = self.job.name
                 if self.named_lora:
@@ -628,6 +628,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 direct_save=direct_save
             )
         else:
+            if self.network is not None and self.train_config.merge_network_on_save:
+                # merge the network weights into a full model and save that
+                if not self.network.can_merge_in:
+                    raise ValueError("Network cannot merge in weights. Cannot save full model.")
+
+                print_acc("Merging network weights into full model for saving...")
+
+                self.network.merge_in(merge_weight=1.0)
+                # reset weights to zero
+                self.network.reset_weights()
+                self.network.is_merged_in = False
+
+                print_acc("Done merging network weights.")
+
             if self.save_config.save_format == "diffusers":
                 # saving as a folder path
                 file_path = file_path.replace('.safetensors', '')
@@ -1538,7 +1552,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.hook_before_model_load()
         model_config_to_load = copy.deepcopy(self.model_config)
 
-        if self.is_fine_tuning:
+        if self.is_fine_tuning or self.train_config.merge_network_on_save:
             # get the latest checkpoint
             # check to see if we have a latest save
             latest_save_path = self.get_latest_save_path()
@@ -1832,7 +1846,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             latest_save_path = self.get_latest_save_path(lora_name)
 
         extra_weights = None
-        if latest_save_path is not None:
+        if latest_save_path is not None and not self.train_config.merge_network_on_save:
             print_acc(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
             print_acc(f"Loading from {latest_save_path}")
             extra_weights = self.load_weights(latest_save_path)
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index b8556f12..4546e573 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -404,6 +404,15 @@ class ToolkitModuleMixin:
         # set weight to org_module
         org_sd[weight_key] = weight.to(weight_device, orig_dtype)
         self.org_module[0].load_state_dict(org_sd)
+
+    def reset_weights(self: Module):
+        # reset the weights to zero
+        org_sd = self.state_dict()
+        for key in org_sd.keys():
+            # only reset lora up
+            if 'lora_up' in key:
+                org_sd[key] = torch.zeros_like(org_sd[key])
+        self.load_state_dict(org_sd)
 
     def setup_lorm(self: Module, state_dict: Optional[Dict[str, Any]] = None):
         # LoRM (Low Rank Middle) is a method reduce the number of parameters in a module while keeping the inputs and
@@ -811,6 +820,10 @@ class ToolkitNetworkMixin:
         # not supported
         self.is_checkpointing = False
         self._update_checkpointing()
+
+    def reset_weights(self: Network):
+        for module in self.get_all_modules():
+            module.reset_weights()
 
     def merge_in(self, merge_weight=1.0):
         if self.network_type.lower() == 'dora':