mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Use peft format for flux loras so they are compatible with diffusers. allow loading an assistant lora
This commit is contained in:
@@ -204,7 +204,6 @@ class ToolkitModuleMixin:
|
||||
|
||||
return lx * scale
|
||||
|
||||
|
||||
def lorm_forward(self: Network, x, *args, **kwargs):
|
||||
network: Network = self.network_ref()
|
||||
if not network.is_active:
|
||||
@@ -492,6 +491,24 @@ class ToolkitNetworkMixin:
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
save_dict[key] = v
|
||||
|
||||
if self.peft_format:
|
||||
# lora_down = lora_A
|
||||
# lora_up = lora_B
|
||||
# no alpha
|
||||
|
||||
new_save_dict = {}
|
||||
for key, value in save_dict.items():
|
||||
if key.endswith('.alpha'):
|
||||
continue
|
||||
new_key = key
|
||||
new_key = new_key.replace('lora_down', 'lora_A')
|
||||
new_key = new_key.replace('lora_up', 'lora_B')
|
||||
# replace all $$ with .
|
||||
new_key = new_key.replace('$$', '.')
|
||||
new_save_dict[new_key] = value
|
||||
|
||||
save_dict = new_save_dict
|
||||
|
||||
if metadata is None:
|
||||
metadata = OrderedDict()
|
||||
metadata = add_model_hash_to_meta(state_dict, metadata)
|
||||
@@ -519,6 +536,20 @@ class ToolkitNetworkMixin:
|
||||
# replace old double __ with single _
|
||||
if self.is_pixart:
|
||||
load_key = load_key.replace('__', '_')
|
||||
|
||||
if self.peft_format:
|
||||
# lora_down = lora_A
|
||||
# lora_up = lora_B
|
||||
# no alpha
|
||||
if load_key.endswith('.alpha'):
|
||||
continue
|
||||
load_key = load_key.replace('lora_A', 'lora_down')
|
||||
load_key = load_key.replace('lora_B', 'lora_up')
|
||||
# replace all . with $$
|
||||
load_key = load_key.replace('.', '$$')
|
||||
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
|
||||
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
|
||||
|
||||
load_sd[load_key] = value
|
||||
|
||||
# extract extra items from state dict
|
||||
@@ -533,7 +564,8 @@ class ToolkitNetworkMixin:
|
||||
del load_sd[key]
|
||||
|
||||
print(f"Missing keys: {to_delete}")
|
||||
if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (len(to_delete) == 1 and 'emb_params' in to_delete):
|
||||
if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (
|
||||
len(to_delete) == 1 and 'emb_params' in to_delete):
|
||||
print(" Attempting to load with forced keymap")
|
||||
return self.load_weights(file, force_weight_mapping=True)
|
||||
|
||||
@@ -657,4 +689,3 @@ class ToolkitNetworkMixin:
|
||||
params_reduced += (num_orig_module_params - num_lorem_params)
|
||||
|
||||
return params_reduced
|
||||
|
||||
|
||||
Reference in New Issue
Block a user