Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Upgrade a LoRA's rank if the newly configured rank is larger, so users can increase the rank on an existing training job and continue training at the higher rank.
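The idea, in rough sketch form (the helper name and tensor layout below are assumptions for illustration, not the actual ai-toolkit code): when the configured rank is larger than the rank stored in an existing LoRA file, the low-rank factors can be zero-padded up to the new rank. The product up @ down is numerically unchanged, so training resumes from the same effective weights, just with extra capacity.

import torch

def upgrade_lora_rank(down: torch.Tensor, up: torch.Tensor, new_rank: int):
    # Hypothetical helper: down is (old_rank, in_features), up is (out_features, old_rank).
    old_rank = down.shape[0]
    if new_rank <= old_rank:
        # only upgrade to a larger rank; never shrink an existing adapter
        return down, up
    # zero-pad the new rows of down and the new columns of up so that
    # up @ down (the effective weight delta) stays exactly the same
    new_down = torch.zeros(new_rank, down.shape[1], dtype=down.dtype, device=down.device)
    new_down[:old_rank] = down
    new_up = torch.zeros(up.shape[0], new_rank, dtype=up.dtype, device=up.device)
    new_up[:, :old_rank] = up
    return new_down, new_up

Zero-padding (rather than randomly initializing the new rows and columns) keeps the model's outputs identical at the moment of the upgrade; the added capacity only starts contributing once training updates it.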
@@ -1923,15 +1923,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
             for group in optimizer.param_groups:
                 previous_lrs.append(group['lr'])
 
-            try:
-                print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
-                optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
-                optimizer.load_state_dict(optimizer_state_dict)
-                del optimizer_state_dict
-                flush()
-            except Exception as e:
-                print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
-                print_acc(e)
+            load_optimizer = True
+            if self.network is not None:
+                if self.network.did_change_weights:
+                    # do not load optimizer if the network changed, it will result in
+                    # a double state that will oom.
+                    load_optimizer = False
+
+            if load_optimizer:
+                try:
+                    print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
+                    optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
+                    optimizer.load_state_dict(optimizer_state_dict)
+                    del optimizer_state_dict
+                    flush()
+                except Exception as e:
+                    print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                    print_acc(e)
 
             # update the optimizer LR from the params
             print_acc(f"Updating optimizer LR from params")
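Why the optimizer state is skipped after such a change: the saved state was recorded against the old, lower-rank parameter shapes. A small standalone sketch (hypothetical shapes, plain PyTorch, not ai-toolkit code) of what goes wrong if it is loaded anyway:

import torch

# Adam-style optimizers keep per-parameter moment buffers shaped like the
# parameters they were created for, so state saved for a rank-4 factor does
# not fit the same factor after it has been upgraded to rank 8.
old_param = torch.nn.Parameter(torch.zeros(4, 16))   # e.g. a rank-4 down weight
opt = torch.optim.AdamW([old_param], lr=1e-4)
old_param.sum().backward()
opt.step()
saved_state = opt.state_dict()                        # moment buffers are (4, 16)

new_param = torch.nn.Parameter(torch.zeros(8, 16))   # the same weight after the rank upgrade
new_opt = torch.optim.AdamW([new_param], lr=1e-4)
new_opt.load_state_dict(saved_state)                  # accepted, but the (4, 16) buffers are kept
new_param.sum().backward()
# new_opt.step() would now fail with a shape mismatch, and until then the stale
# buffers sit alongside the new parameters and gradients, which is presumably the
# "double state" the diff's comment guards against by starting the optimizer fresh.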