Added initial support to initiate lora training from an existing lora

This commit is contained in:
Jaret Burkett
2025-12-22 12:49:15 -07:00
parent 91342853c1
commit 87edca1b2b
2 changed files with 15 additions and 0 deletions

View File

@@ -816,12 +816,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
if len(paths) > 0: if len(paths) > 0:
latest_path = max(paths, key=os.path.getctime) latest_path = max(paths, key=os.path.getctime)
if latest_path is None and self.network_config is not None and self.network_config.pretrained_lora_path is not None:
# set pretrained lora path as load path if we do not have a checkpoint to resume from
if os.path.exists(self.network_config.pretrained_lora_path):
latest_path = self.network_config.pretrained_lora_path
print_acc(f"Using pretrained lora path from config: {latest_path}")
else:
# no pretrained lora found
print_acc(f"Pretrained lora path from config does not exist: {self.network_config.pretrained_lora_path}")
return latest_path return latest_path
def load_training_state_from_metadata(self, path): def load_training_state_from_metadata(self, path):
if not self.accelerator.is_main_process: if not self.accelerator.is_main_process:
return return
if path is not None and self.network_config is not None and path == self.network_config.pretrained_lora_path:
# dont load metadata from pretrained lora
return
meta = None meta = None
# if path is folder, then it is diffusers # if path is folder, then it is diffusers
if os.path.isdir(path): if os.path.isdir(path):

View File

@@ -212,6 +212,9 @@ class NetworkConfig:
# ramtorch, doesn't work yet # ramtorch, doesn't work yet
self.layer_offloading = kwargs.get('layer_offloading', False) self.layer_offloading = kwargs.get('layer_offloading', False)
# start from a pretrained lora
self.pretrained_lora_path = kwargs.get('pretrained_lora_path', None)
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v'] AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']