Upgrade a LoRA rank if the new one is larger so users can increase the rank on an exiting training job and continue training at a higher rank.

This commit is contained in:
Jaret Burkett
2025-08-24 13:40:25 -06:00
parent 24372b5e35
commit f48d21caee
2 changed files with 68 additions and 9 deletions

View File

@@ -1923,15 +1923,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
for group in optimizer.param_groups: for group in optimizer.param_groups:
previous_lrs.append(group['lr']) previous_lrs.append(group['lr'])
try: load_optimizer = True
print_acc(f"Loading optimizer state from {optimizer_state_file_path}") if self.network is not None:
optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True) if self.network.did_change_weights:
optimizer.load_state_dict(optimizer_state_dict) # do not load optimizer if the network changed, it will result in
del optimizer_state_dict # a double state that will oom.
flush() load_optimizer = False
except Exception as e:
print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}") if load_optimizer:
print_acc(e) try:
print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
optimizer.load_state_dict(optimizer_state_dict)
del optimizer_state_dict
flush()
except Exception as e:
print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
print_acc(e)
# update the optimizer LR from the params # update the optimizer LR from the params
print_acc(f"Updating optimizer LR from params") print_acc(f"Updating optimizer LR from params")

View File

@@ -46,6 +46,15 @@ ExtractMode = Union[
'percentage' 'percentage'
] ]
printed_messages = []
def print_once(msg):
global printed_messages
if msg not in printed_messages:
print(msg)
printed_messages.append(msg)
def broadcast_and_multiply(tensor, multiplier): def broadcast_and_multiply(tensor, multiplier):
# Determine the number of dimensions required # Determine the number of dimensions required
@@ -434,6 +443,8 @@ class ToolkitNetworkMixin:
self.module_losses: List[torch.Tensor] = [] self.module_losses: List[torch.Tensor] = []
self.lorm_train_mode: Literal['local', None] = None self.lorm_train_mode: Literal['local', None] = None
self.can_merge_in = not is_lorm self.can_merge_in = not is_lorm
# will prevent optimizer from loading as it will have double states
self.did_change_weights = False
def get_keymap(self: Network, force_weight_mapping=False): def get_keymap(self: Network, force_weight_mapping=False):
use_weight_mapping = False use_weight_mapping = False
@@ -634,6 +645,46 @@ class ToolkitNetworkMixin:
if key not in current_state_dict: if key not in current_state_dict:
extra_dict[key] = load_sd[key] extra_dict[key] = load_sd[key]
to_delete.append(key) to_delete.append(key)
elif "lora_down" in key or "lora_up" in key:
# handle expanding/shrinking LoRA (linear only)
if len(load_sd[key].shape) == 2:
load_value = load_sd[key] # from checkpoint
blank_val = current_state_dict[key] # shape we need in the target model
tgt_h, tgt_w = blank_val.shape
src_h, src_w = load_value.shape
if (src_h, src_w) == (tgt_h, tgt_w):
# shapes already match: keep original
pass
elif "lora_down" in key and src_h < tgt_h:
print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
new_val[:src_h, :src_w] = load_value # src_w should already match
load_sd[key] = new_val
self.did_change_weights = True
elif "lora_up" in key and src_w < tgt_w:
print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
new_val[:src_h, :src_w] = load_value # src_h should already match
load_sd[key] = new_val
self.did_change_weights = True
elif "lora_down" in key and src_h > tgt_h:
print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
load_sd[key] = load_value[:tgt_h, :tgt_w]
self.did_change_weights = True
elif "lora_up" in key and src_w > tgt_w:
print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
load_sd[key] = load_value[:tgt_h, :tgt_w]
self.did_change_weights = True
else:
# unexpected mismatch (e.g., both dims differ in a way that doesn't match lora_up/down semantics)
raise ValueError(f"Unhandled LoRA shape change for {key}: src={load_value.shape}, tgt={blank_val.shape}")
for key in to_delete: for key in to_delete:
del load_sd[key] del load_sd[key]