mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added initial support to initiate lora training from an existing lora
This commit is contained in:
@@ -816,12 +816,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
if len(paths) > 0:
|
if len(paths) > 0:
|
||||||
latest_path = max(paths, key=os.path.getctime)
|
latest_path = max(paths, key=os.path.getctime)
|
||||||
|
|
||||||
|
if latest_path is None and self.network_config is not None and self.network_config.pretrained_lora_path is not None:
|
||||||
|
# set pretrained lora path as load path if we do not have a checkpoint to resume from
|
||||||
|
if os.path.exists(self.network_config.pretrained_lora_path):
|
||||||
|
latest_path = self.network_config.pretrained_lora_path
|
||||||
|
print_acc(f"Using pretrained lora path from config: {latest_path}")
|
||||||
|
else:
|
||||||
|
# no pretrained lora found
|
||||||
|
print_acc(f"Pretrained lora path from config does not exist: {self.network_config.pretrained_lora_path}")
|
||||||
|
|
||||||
return latest_path
|
return latest_path
|
||||||
|
|
||||||
def load_training_state_from_metadata(self, path):
|
def load_training_state_from_metadata(self, path):
|
||||||
if not self.accelerator.is_main_process:
|
if not self.accelerator.is_main_process:
|
||||||
return
|
return
|
||||||
|
if path is not None and self.network_config is not None and path == self.network_config.pretrained_lora_path:
|
||||||
|
# dont load metadata from pretrained lora
|
||||||
|
return
|
||||||
meta = None
|
meta = None
|
||||||
# if path is folder, then it is diffusers
|
# if path is folder, then it is diffusers
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
|
|||||||
@@ -212,6 +212,9 @@ class NetworkConfig:
|
|||||||
|
|
||||||
# ramtorch, doesn't work yet
|
# ramtorch, doesn't work yet
|
||||||
self.layer_offloading = kwargs.get('layer_offloading', False)
|
self.layer_offloading = kwargs.get('layer_offloading', False)
|
||||||
|
|
||||||
|
# start from a pretrained lora
|
||||||
|
self.pretrained_lora_path = kwargs.get('pretrained_lora_path', None)
|
||||||
|
|
||||||
|
|
||||||
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
|
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
|
||||||
|
|||||||
Reference in New Issue
Block a user