From f48d21caee294428d7b1dd32139d94bf41dae239 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 24 Aug 2025 13:40:25 -0600
Subject: [PATCH] Upgrade a LoRA rank if the new one is larger, so users can
 increase the rank on an existing training job and continue training at a
 higher rank.

---
 jobs/process/BaseSDTrainProcess.py | 26 +++++++++-------
 toolkit/network_mixins.py          | 51 ++++++++++++++++++++++++++++++
 2 files changed, 68 insertions(+), 9 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 84c941bb..43508d4f 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1923,15 +1923,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
         for group in optimizer.param_groups:
             previous_lrs.append(group['lr'])
 
-        try:
-            print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
-            optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
-            optimizer.load_state_dict(optimizer_state_dict)
-            del optimizer_state_dict
-            flush()
-        except Exception as e:
-            print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
-            print_acc(e)
+        load_optimizer = True
+        if self.network is not None:
+            if self.network.did_change_weights:
+                # do not load the optimizer state if the network changed;
+                # it would result in a double state that will OOM.
+                load_optimizer = False
+
+        if load_optimizer:
+            try:
+                print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
+                optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
+                optimizer.load_state_dict(optimizer_state_dict)
+                del optimizer_state_dict
+                flush()
+            except Exception as e:
+                print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                print_acc(e)
 
         # update the optimizer LR from the params
         print_acc(f"Updating optimizer LR from params")
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 59c15d3d..a25b6255 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -46,6 +46,15 @@ ExtractMode = Union[
     'percentage'
 ]
 
+printed_messages = []
+
+
+def print_once(msg):
+    global printed_messages
+    if msg not in printed_messages:
+        print(msg)
+        printed_messages.append(msg)
+
 
 def broadcast_and_multiply(tensor, multiplier):
     # Determine the number of dimensions required
@@ -434,6 +443,8 @@ class ToolkitNetworkMixin:
         self.module_losses: List[torch.Tensor] = []
         self.lorm_train_mode: Literal['local', None] = None
         self.can_merge_in = not is_lorm
+        # will prevent the optimizer state from loading, as it would hold double states
+        self.did_change_weights = False
 
     def get_keymap(self: Network, force_weight_mapping=False):
         use_weight_mapping = False
@@ -634,6 +645,46 @@ class ToolkitNetworkMixin:
                 if key not in current_state_dict:
                     extra_dict[key] = load_sd[key]
                     to_delete.append(key)
+                elif "lora_down" in key or "lora_up" in key:
+                    # handle expanding/shrinking a LoRA (linear layers only)
+                    if len(load_sd[key].shape) == 2:
+                        load_value = load_sd[key]  # from checkpoint
+                        blank_val = current_state_dict[key]  # shape we need in the target model
+                        tgt_h, tgt_w = blank_val.shape
+                        src_h, src_w = load_value.shape
+
+                        if (src_h, src_w) == (tgt_h, tgt_w):
+                            # shapes already match: keep original
+                            pass
+
+                        elif "lora_down" in key and src_h < tgt_h:
+                            print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
+                            new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
+                            new_val[:src_h, :src_w] = load_value  # src_w should already match
+                            load_sd[key] = new_val
+                            self.did_change_weights = True
+
+                        elif "lora_up" in key and src_w < tgt_w:
+                            print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
+                            new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
+                            new_val[:src_h, :src_w] = load_value  # src_h should already match
+                            load_sd[key] = new_val
+                            self.did_change_weights = True
+
+                        elif "lora_down" in key and src_h > tgt_h:
+                            print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
+                            load_sd[key] = load_value[:tgt_h, :tgt_w]
+                            self.did_change_weights = True
+
+                        elif "lora_up" in key and src_w > tgt_w:
+                            print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
+                            load_sd[key] = load_value[:tgt_h, :tgt_w]
+                            self.did_change_weights = True
+
+                        else:
+                            # unexpected mismatch (e.g., both dims differ in a way that doesn't match lora_up/down semantics)
+                            raise ValueError(f"Unhandled LoRA shape change for {key}: src={load_value.shape}, tgt={blank_val.shape}")
+
         for key in to_delete:
             del load_sd[key]
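
Note for reviewers (not part of the patch): a minimal standalone sketch of why the zero-padded expansion above can resume training cleanly. It assumes the linear-LoRA layout used here, lora_down of shape (rank, in_features) and lora_up of shape (out_features, rank); the sizes below are hypothetical and for illustration only. Copying the old weights into the top-left corner of zero tensors leaves the effective weight delta lora_up @ lora_down unchanged, because the new rows of lora_down only ever meet the new (zero) columns of lora_up.

import torch

in_features, out_features = 64, 128
old_rank, new_rank = 4, 16  # e.g. resuming a rank-4 LoRA at rank 16

lora_down = torch.randn(old_rank, in_features)   # stands in for the checkpoint weights
lora_up = torch.randn(out_features, old_rank)

# Expand the same way the patch does: zeros of the target shape,
# with the old values copied into the top-left corner.
down_new = torch.zeros(new_rank, in_features)
down_new[:old_rank, :in_features] = lora_down    # extra rank rows stay zero
up_new = torch.zeros(out_features, new_rank)
up_new[:out_features, :old_rank] = lora_up       # extra rank columns stay zero

# The effective delta is unchanged, so the higher-rank LoRA starts from
# exactly the same point and the new rank capacity is free to train.
assert torch.allclose(up_new @ down_new, lora_up @ lora_down)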