mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Upgrade a LoRA rank if the new one is larger so users can increase the rank on an exiting training job and continue training at a higher rank.
This commit is contained in:
@@ -1923,15 +1923,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
for group in optimizer.param_groups:
|
||||
previous_lrs.append(group['lr'])
|
||||
|
||||
try:
|
||||
print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
|
||||
optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
del optimizer_state_dict
|
||||
flush()
|
||||
except Exception as e:
|
||||
print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
||||
print_acc(e)
|
||||
load_optimizer = True
|
||||
if self.network is not None:
|
||||
if self.network.did_change_weights:
|
||||
# do not load optimizer if the network changed, it will result in
|
||||
# a double state that will oom.
|
||||
load_optimizer = False
|
||||
|
||||
if load_optimizer:
|
||||
try:
|
||||
print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
|
||||
optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
del optimizer_state_dict
|
||||
flush()
|
||||
except Exception as e:
|
||||
print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
|
||||
print_acc(e)
|
||||
|
||||
# update the optimizer LR from the params
|
||||
print_acc(f"Updating optimizer LR from params")
|
||||
|
||||
@@ -46,6 +46,15 @@ ExtractMode = Union[
|
||||
'percentage'
|
||||
]
|
||||
|
||||
printed_messages = []
|
||||
|
||||
|
||||
def print_once(msg):
|
||||
global printed_messages
|
||||
if msg not in printed_messages:
|
||||
print(msg)
|
||||
printed_messages.append(msg)
|
||||
|
||||
|
||||
def broadcast_and_multiply(tensor, multiplier):
|
||||
# Determine the number of dimensions required
|
||||
@@ -434,6 +443,8 @@ class ToolkitNetworkMixin:
|
||||
self.module_losses: List[torch.Tensor] = []
|
||||
self.lorm_train_mode: Literal['local', None] = None
|
||||
self.can_merge_in = not is_lorm
|
||||
# will prevent optimizer from loading as it will have double states
|
||||
self.did_change_weights = False
|
||||
|
||||
def get_keymap(self: Network, force_weight_mapping=False):
|
||||
use_weight_mapping = False
|
||||
@@ -634,6 +645,46 @@ class ToolkitNetworkMixin:
|
||||
if key not in current_state_dict:
|
||||
extra_dict[key] = load_sd[key]
|
||||
to_delete.append(key)
|
||||
elif "lora_down" in key or "lora_up" in key:
|
||||
# handle expanding/shrinking LoRA (linear only)
|
||||
if len(load_sd[key].shape) == 2:
|
||||
load_value = load_sd[key] # from checkpoint
|
||||
blank_val = current_state_dict[key] # shape we need in the target model
|
||||
tgt_h, tgt_w = blank_val.shape
|
||||
src_h, src_w = load_value.shape
|
||||
|
||||
if (src_h, src_w) == (tgt_h, tgt_w):
|
||||
# shapes already match: keep original
|
||||
pass
|
||||
|
||||
elif "lora_down" in key and src_h < tgt_h:
|
||||
print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
|
||||
new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
|
||||
new_val[:src_h, :src_w] = load_value # src_w should already match
|
||||
load_sd[key] = new_val
|
||||
self.did_change_weights = True
|
||||
|
||||
elif "lora_up" in key and src_w < tgt_w:
|
||||
print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
|
||||
new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
|
||||
new_val[:src_h, :src_w] = load_value # src_h should already match
|
||||
load_sd[key] = new_val
|
||||
self.did_change_weights = True
|
||||
|
||||
elif "lora_down" in key and src_h > tgt_h:
|
||||
print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
|
||||
load_sd[key] = load_value[:tgt_h, :tgt_w]
|
||||
self.did_change_weights = True
|
||||
|
||||
elif "lora_up" in key and src_w > tgt_w:
|
||||
print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
|
||||
load_sd[key] = load_value[:tgt_h, :tgt_w]
|
||||
self.did_change_weights = True
|
||||
|
||||
else:
|
||||
# unexpected mismatch (e.g., both dims differ in a way that doesn't match lora_up/down semantics)
|
||||
raise ValueError(f"Unhandled LoRA shape change for {key}: src={load_value.shape}, tgt={blank_val.shape}")
|
||||
|
||||
for key in to_delete:
|
||||
del load_sd[key]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user