diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 2c407160..f90fc8f8 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -816,12 +816,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if len(paths) > 0:
             latest_path = max(paths, key=os.path.getctime)
+
+        if latest_path is None and self.network_config is not None and self.network_config.pretrained_lora_path is not None:
+            # set the pretrained lora path as the load path if we do not have a checkpoint to resume from
+            if os.path.exists(self.network_config.pretrained_lora_path):
+                latest_path = self.network_config.pretrained_lora_path
+                print_acc(f"Using pretrained lora path from config: {latest_path}")
+            else:
+                # no pretrained lora found
+                print_acc(f"Pretrained lora path from config does not exist: {self.network_config.pretrained_lora_path}")
 
         return latest_path
 
     def load_training_state_from_metadata(self, path):
         if not self.accelerator.is_main_process:
             return
+        if path is not None and self.network_config is not None and path == self.network_config.pretrained_lora_path:
+            # don't load training metadata from a pretrained lora
+            return
         meta = None
         # if path is folder, then it is diffusers
         if os.path.isdir(path):
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index d29de574..7d2895cd 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -212,6 +212,9 @@ class NetworkConfig:
         # ramtorch, doesn't work yet
         self.layer_offloading = kwargs.get('layer_offloading', False)
+
+        # start from a pretrained lora
+        self.pretrained_lora_path = kwargs.get('pretrained_lora_path', None)
 
 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
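
For reviewers, here is a minimal standalone sketch of the fallback behavior this patch introduces. The function name, signature, and paths are illustrative stand-ins, not the repo's actual API: the point is the ordering, where the newest local checkpoint always takes priority and the configured pretrained LoRA is only used when there is nothing to resume from.

```python
# Minimal sketch of the resume-path resolution added above; assumes a plain
# list of checkpoint paths and a config value, not the repo's actual objects.
import os
from typing import List, Optional


def resolve_load_path(
    checkpoint_paths: List[str],
    pretrained_lora_path: Optional[str],
) -> Optional[str]:
    # the newest local checkpoint always wins
    latest_path = None
    if len(checkpoint_paths) > 0:
        latest_path = max(checkpoint_paths, key=os.path.getctime)

    # fall back to the configured pretrained lora only when there is
    # no checkpoint to resume from
    if latest_path is None and pretrained_lora_path is not None:
        if os.path.exists(pretrained_lora_path):
            latest_path = pretrained_lora_path
        else:
            print(f"Pretrained lora path does not exist: {pretrained_lora_path}")
    return latest_path
```

Note the guard in the second hunk: when the resolved path is the pretrained LoRA rather than a real checkpoint, `load_training_state_from_metadata` returns early, so training state stored in the file's metadata (such as the step count) is not restored and the run starts fresh from the pretrained weights.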