Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Upgrade a LoRA's rank when the new rank is larger, so users can increase the rank on an existing training job and continue training at the higher rank.
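Why resuming at a higher rank works: a LoRA delta is lora_up @ lora_down, and zero-padding lora_down with extra rows and lora_up with extra columns raises the rank without changing that product (any alpha/rank scaling applied on top is a separate concern). A minimal standalone sketch of the idea, not ai-toolkit code; the names and sizes below are illustrative:

# Minimal sketch (not ai-toolkit code) of why a rank upgrade can resume cleanly:
# a LoRA delta is lora_up @ lora_down, and zero-padding keeps that product identical.
import torch

out_features, in_features, old_rank, new_rank = 128, 64, 16, 32  # illustrative sizes
lora_down = torch.randn(old_rank, in_features)   # [rank, in]
lora_up = torch.randn(out_features, old_rank)    # [out, rank]

# zero-pad, keeping the old block in the top rows / left columns
new_down = torch.zeros(new_rank, in_features)
new_down[:old_rank] = lora_down
new_up = torch.zeros(out_features, new_rank)
new_up[:, :old_rank] = lora_up

# the extra rank slots contribute nothing until training fills them in
assert torch.allclose(lora_up @ lora_down, new_up @ new_down)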
@@ -1923,15 +1923,23 @@ class BaseSDTrainProcess(BaseTrainProcess):
             for group in optimizer.param_groups:
                 previous_lrs.append(group['lr'])
 
-            try:
-                print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
-                optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
-                optimizer.load_state_dict(optimizer_state_dict)
-                del optimizer_state_dict
-                flush()
-            except Exception as e:
-                print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
-                print_acc(e)
+            load_optimizer = True
+            if self.network is not None:
+                if self.network.did_change_weights:
+                    # do not load optimizer if the network changed, it will result in
+                    # a double state that will oom.
+                    load_optimizer = False
+
+            if load_optimizer:
+                try:
+                    print_acc(f"Loading optimizer state from {optimizer_state_file_path}")
+                    optimizer_state_dict = torch.load(optimizer_state_file_path, weights_only=True)
+                    optimizer.load_state_dict(optimizer_state_dict)
+                    del optimizer_state_dict
+                    flush()
+                except Exception as e:
+                    print_acc(f"Failed to load optimizer state from {optimizer_state_file_path}")
+                    print_acc(e)
 
             # update the optimizer LR from the params
             print_acc(f"Updating optimizer LR from params")
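For context on the did_change_weights guard above: Adam-style optimizers keep per-parameter moment buffers shaped like the old weights, so restoring a state saved before a rank change leaves stale buffers that break on the next step and, until discarded, sit in memory next to the new parameters. A minimal sketch, not ai-toolkit code, using plain torch.optim.AdamW:

# Minimal sketch (not ai-toolkit code): why the optimizer state is skipped after a
# rank change. AdamW keeps exp_avg / exp_avg_sq buffers shaped like the old weights;
# load_state_dict() does not validate shapes, so the stale buffers fail on step().
import torch

param_r16 = torch.nn.Parameter(torch.randn(16, 64))  # e.g. a lora_down weight at rank 16
opt = torch.optim.AdamW([param_r16], lr=1e-4)
param_r16.grad = torch.randn_like(param_r16)
opt.step()                                            # creates the moment buffers
saved = opt.state_dict()

param_r32 = torch.nn.Parameter(torch.randn(32, 64))  # same layer after upgrading to rank 32
opt2 = torch.optim.AdamW([param_r32], lr=1e-4)
opt2.load_state_dict(saved)                           # loads without shape checks
param_r32.grad = torch.randn_like(param_r32)
try:
    opt2.step()
except RuntimeError as err:
    print("stale optimizer state:", err)              # buffer shape (16, 64) vs grad (32, 64)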
@@ -46,6 +46,15 @@ ExtractMode = Union[
     'percentage'
 ]
 
+printed_messages = []
+
+
+def print_once(msg):
+    global printed_messages
+    if msg not in printed_messages:
+        print(msg)
+        printed_messages.append(msg)
+
 
 def broadcast_and_multiply(tensor, multiplier):
     # Determine the number of dimensions required
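Usage sketch for the print_once helper added above: it de-duplicates by exact message text, so a warning emitted from inside a per-key or per-step loop is printed only once per unique message. The message text here is hypothetical:

for step in range(1000):
    print_once("LoRA weights were resized; optimizer state will not be restored")
# the warning appears a single time, not 1000 times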
@@ -434,6 +443,8 @@ class ToolkitNetworkMixin:
         self.module_losses: List[torch.Tensor] = []
         self.lorm_train_mode: Literal['local', None] = None
         self.can_merge_in = not is_lorm
+        # will prevent optimizer from loading as it will have double states
+        self.did_change_weights = False
 
     def get_keymap(self: Network, force_weight_mapping=False):
         use_weight_mapping = False
@@ -634,6 +645,46 @@ class ToolkitNetworkMixin:
             if key not in current_state_dict:
                 extra_dict[key] = load_sd[key]
                 to_delete.append(key)
+            elif "lora_down" in key or "lora_up" in key:
+                # handle expanding/shrinking LoRA (linear only)
+                if len(load_sd[key].shape) == 2:
+                    load_value = load_sd[key]  # from checkpoint
+                    blank_val = current_state_dict[key]  # shape we need in the target model
+                    tgt_h, tgt_w = blank_val.shape
+                    src_h, src_w = load_value.shape
+
+                    if (src_h, src_w) == (tgt_h, tgt_w):
+                        # shapes already match: keep original
+                        pass
+
+                    elif "lora_down" in key and src_h < tgt_h:
+                        print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
+                        new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
+                        new_val[:src_h, :src_w] = load_value  # src_w should already match
+                        load_sd[key] = new_val
+                        self.did_change_weights = True
+
+                    elif "lora_up" in key and src_w < tgt_w:
+                        print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
+                        new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
+                        new_val[:src_h, :src_w] = load_value  # src_h should already match
+                        load_sd[key] = new_val
+                        self.did_change_weights = True
+
+                    elif "lora_down" in key and src_h > tgt_h:
+                        print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
+                        load_sd[key] = load_value[:tgt_h, :tgt_w]
+                        self.did_change_weights = True
+
+                    elif "lora_up" in key and src_w > tgt_w:
+                        print_once(f"Shrinking {key} from {load_value.shape} to {blank_val.shape}")
+                        load_sd[key] = load_value[:tgt_h, :tgt_w]
+                        self.did_change_weights = True
+
+                    else:
+                        # unexpected mismatch (e.g., both dims differ in a way that doesn't match lora_up/down semantics)
+                        raise ValueError(f"Unhandled LoRA shape change for {key}: src={load_value.shape}, tgt={blank_val.shape}")
+
         for key in to_delete:
             del load_sd[key]
 
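One caveat the shrink branches above imply: unlike zero-padded expansion, truncating to a lower rank keeps only the first target rows/columns, so the product lora_up @ lora_down generally changes and training has to recover from it. A minimal sketch, not ai-toolkit code, with illustrative shapes:

# Minimal sketch (not ai-toolkit code): shrinking truncates to the first tgt rows/columns,
# which, unlike zero-padded expansion, generally changes the product lora_up @ lora_down.
import torch

out_features, in_features, old_rank, new_rank = 128, 64, 32, 16  # illustrative sizes
lora_down = torch.randn(old_rank, in_features)
lora_up = torch.randn(out_features, old_rank)

small_down = lora_down[:new_rank, :]   # what the lora_down shrink branch keeps
small_up = lora_up[:, :new_rank]       # what the lora_up shrink branch keeps

print(torch.allclose(lora_up @ lora_down, small_up @ small_down))  # almost surely False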