Upgrade a LoRA to a higher rank when the new rank is larger, so users can increase the rank on an existing training job and continue training at the higher rank.

Jaret Burkett
2025-08-24 13:40:25 -06:00
parent 24372b5e35
commit f48d21caee
2 changed files with 68 additions and 9 deletions
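The network-side change (the second file in this commit) is not shown in the hunk below, but the idea behind a rank upgrade can be sketched: pad the existing LoRA down/up matrices out to the new rank so the adapter's output is unchanged, then keep training. The sketch below is a minimal illustration under that assumption; the function and tensor names are illustrative, not from the repo, and whether the repo also adjusts alpha/rank scaling is not shown here.

```python
import torch


def upgrade_lora_rank(lora_down: torch.Tensor, lora_up: torch.Tensor, new_rank: int):
    """Pad a LoRA pair (down: [rank, in], up: [out, rank]) to a larger rank.

    New down rows are randomly initialized (as at LoRA init) and new up
    columns are zeros, so up @ down is numerically unchanged, while the
    extra rank directions can still receive gradients once training resumes.
    """
    old_rank, in_features = lora_down.shape
    out_features = lora_up.shape[0]
    assert new_rank > old_rank, "only upgrades to a larger rank are supported"

    new_down = torch.empty(new_rank, in_features, dtype=lora_down.dtype, device=lora_down.device)
    torch.nn.init.kaiming_uniform_(new_down, a=5 ** 0.5)
    new_down[:old_rank] = lora_down

    new_up = torch.zeros(out_features, new_rank, dtype=lora_up.dtype, device=lora_up.device)
    new_up[:, :old_rank] = lora_up
    return new_down, new_up
```

The hunk below handles the training-loop side of the same feature: once the network reports `did_change_weights`, the saved optimizer state is skipped instead of loaded.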

@@ -1923,15 +1923,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
             for group in optimizer.param_groups:
                 previous_lrs.append(group['lr'])
-            try:
-                print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
-                optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
-                optimizer.load_state_dict(optimizer_state_dict)
-                del optimizer_state_dict
-                flush()
-            except Exception as e:
-                print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
-                print_acc(e)
+            load_optimizer = True
+            if self.network is not None:
+                if self.network.did_change_weights:
+                    # do not load optimizer if the network changed, it will result in
+                    # a double state that will oom.
+                    load_optimizer = False
+            if load_optimizer:
+                try:
+                    print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
+                    optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
+                    optimizer.load_state_dict(optimizer_state_dict)
+                    del optimizer_state_dict
+                    flush()
+                except Exception as e:
+                    print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                    print_acc(e)
             # update the optimizer LR from the params
             print_acc(f"Updating optimizer LR from params")