Use peft format for flux loras so they are compatible with diffusers. allow loading an assistant lora

This commit is contained in:
Jaret Burkett
2024-08-05 14:34:37 -06:00
parent edb7e827ee
commit 187663ab55
4 changed files with 87 additions and 6 deletions

View File

@@ -204,7 +204,6 @@ class ToolkitModuleMixin:
return lx * scale
def lorm_forward(self: Network, x, *args, **kwargs):
network: Network = self.network_ref()
if not network.is_active:
@@ -492,6 +491,24 @@ class ToolkitNetworkMixin:
v = v.detach().clone().to("cpu").to(dtype)
save_dict[key] = v
if self.peft_format:
# lora_down = lora_A
# lora_up = lora_B
# no alpha
new_save_dict = {}
for key, value in save_dict.items():
if key.endswith('.alpha'):
continue
new_key = key
new_key = new_key.replace('lora_down', 'lora_A')
new_key = new_key.replace('lora_up', 'lora_B')
# replace all $$ with .
new_key = new_key.replace('$$', '.')
new_save_dict[new_key] = value
save_dict = new_save_dict
if metadata is None:
metadata = OrderedDict()
metadata = add_model_hash_to_meta(state_dict, metadata)
@@ -519,6 +536,20 @@ class ToolkitNetworkMixin:
# replace old double __ with single _
if self.is_pixart:
load_key = load_key.replace('__', '_')
if self.peft_format:
# lora_down = lora_A
# lora_up = lora_B
# no alpha
if load_key.endswith('.alpha'):
continue
load_key = load_key.replace('lora_A', 'lora_down')
load_key = load_key.replace('lora_B', 'lora_up')
# replace all . with $$
load_key = load_key.replace('.', '$$')
load_key = load_key.replace('$$lora_down$$', '.lora_down.')
load_key = load_key.replace('$$lora_up$$', '.lora_up.')
load_sd[load_key] = value
# extract extra items from state dict
@@ -533,7 +564,8 @@ class ToolkitNetworkMixin:
del load_sd[key]
print(f"Missing keys: {to_delete}")
if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (len(to_delete) == 1 and 'emb_params' in to_delete):
if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (
len(to_delete) == 1 and 'emb_params' in to_delete):
print(" Attempting to load with forced keymap")
return self.load_weights(file, force_weight_mapping=True)
@@ -657,4 +689,3 @@ class ToolkitNetworkMixin:
params_reduced += (num_orig_module_params - num_lorem_params)
return params_reduced