Upgrade a LoRA's rank if the new rank is larger, so users can increase the rank on an existing training job and continue training at the higher rank.

Jaret Burkett
2025-08-24 13:40:25 -06:00
parent 24372b5e35
commit f48d21caee
2 changed files with 68 additions and 9 deletions


@@ -46,6 +46,15 @@ ExtractMode = Union[
    'percentage'
]

printed_messages = []


def print_once(msg):
    global printed_messages
    if msg not in printed_messages:
        print(msg)
        printed_messages.append(msg)


def broadcast_and_multiply(tensor, multiplier):
    # Determine the number of dimensions required
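For context, print_once logs each distinct message only once, so the per-key expand/shrink notices added later in this commit are not repeated if the load path runs more than once. A trivial illustration (the message text here is made up):

print_once("Expanding lora_down from torch.Size([4, 768]) to torch.Size([8, 768])")  # prints
print_once("Expanding lora_down from torch.Size([4, 768]) to torch.Size([8, 768])")  # suppressed: already printed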
@@ -434,6 +443,8 @@ class ToolkitNetworkMixin:
        self.module_losses: List[torch.Tensor] = []
        self.lorm_train_mode: Literal['local', None] = None
        self.can_merge_in = not is_lorm
        # set when loaded weights are resized; will prevent the saved optimizer
        # state from loading, since it would no longer match the resized weights
        self.did_change_weights = False

    def get_keymap(self: Network, force_weight_mapping=False):
        use_weight_mapping = False
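The did_change_weights flag only records that the loaded LoRA weights were resized; acting on it is left to the resume logic. A hypothetical sketch of how a caller might consult it; the maybe_load_optimizer_state, network, optimizer, and optimizer_state names are illustrative, not the toolkit's actual API:

def maybe_load_optimizer_state(network, optimizer, optimizer_state):
    # Hypothetical resume step (illustrative names): if the LoRA weights were
    # padded or sliced to a new rank, the saved optimizer state no longer lines
    # up with the parameters, so start the optimizer fresh instead of restoring it.
    if network.did_change_weights:
        print("LoRA rank changed; starting with a fresh optimizer state")
    else:
        optimizer.load_state_dict(optimizer_state)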
@@ -634,6 +645,46 @@ class ToolkitNetworkMixin:
            if key not in current_state_dict:
                extra_dict[key] = load_sd[key]
                to_delete.append(key)
            elif "lora_down" in key or "lora_up" in key:
                # handle expanding/shrinking a LoRA rank (linear layers only)
                if len(load_sd[key].shape) == 2:
                    load_value = load_sd[key]  # weight from the checkpoint
                    blank_val = current_state_dict[key]  # shape we need in the target model
                    tgt_h, tgt_w = blank_val.shape
                    src_h, src_w = load_value.shape
                    if (src_h, src_w) == (tgt_h, tgt_w):
                        # shapes already match: keep the original weight
                        pass
                    elif "lora_down" in key and src_h < tgt_h:
                        # expanding: pad the new rank rows with zeros
                        print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
                        new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
                        new_val[:src_h, :src_w] = load_value  # src_w should already match tgt_w
                        load_sd[key] = new_val
                        self.did_change_weights = True
                    elif "lora_up" in key and src_w < tgt_w:
                        # expanding: pad the new rank columns with zeros
                        print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
                        new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
                        new_val[:src_h, :src_w] = load_value  # src_h should already match tgt_h
                        load_sd[key] = new_val
                        self.did_change_weights = True
                    elif "lora_down" in key and src_h > tgt_h:
                        # shrinking: drop the extra rank rows
                        print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
                        load_sd[key] = load_value[:tgt_h, :tgt_w]
                        self.did_change_weights = True
                    elif "lora_up" in key and src_w > tgt_w:
                        # shrinking: drop the extra rank columns
                        print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
                        load_sd[key] = load_value[:tgt_h, :tgt_w]
                        self.did_change_weights = True
                    else:
                        # unexpected mismatch (e.g., both dims differ in a way that does not
                        # match lora_up/lora_down semantics)
                        raise ValueError(f"Unhandled LoRA shape change for {key}: src={load_value.shape}, tgt={blank_val.shape}")
        for key in to_delete:
            del load_sd[key]
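Why zero-padding works: for a linear LoRA the weight update is built from lora_up @ lora_down, and padding lora_down with zero rows while padding lora_up with matching zero columns leaves that product unchanged. Resuming at the higher rank therefore starts from exactly the adapter the checkpoint encoded, with the extra rank slots free to learn. A standalone sketch with made-up dimensions, using plain tensors rather than the toolkit's modules:

import torch

in_features, out_features = 768, 768
old_rank, new_rank = 4, 8

# A "checkpoint" LoRA pair at rank 4.
lora_down = torch.randn(old_rank, in_features)   # (rank, in_features)
lora_up = torch.randn(out_features, old_rank)    # (out_features, rank)

# Expand to rank 8 the same way the loader does: zero-fill the new rank slots.
down_new = torch.zeros(new_rank, in_features)
down_new[:old_rank, :] = lora_down
up_new = torch.zeros(out_features, new_rank)
up_new[:, :old_rank] = lora_up

# The effective weight delta is identical, so training at the higher rank
# resumes from the same adapter the checkpoint encoded.
assert torch.allclose(up_new @ down_new, lora_up @ lora_down)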