diff --git a/testing/compare_keys.py b/testing/compare_keys.py
new file mode 100644
index 00000000..021178bd
--- /dev/null
+++ b/testing/compare_keys.py
@@ -0,0 +1,89 @@
+import argparse
+import os
+
+import torch
+from safetensors.torch import load_file
+from collections import OrderedDict
+import json
+# this was just used to match the vae keys to the diffusers keys
+# you probably won't need this. Unless they change them.... again... again
+# on second thought, you probably will
+
+device = torch.device('cpu')
+dtype = torch.float32
+
+parser = argparse.ArgumentParser()
+
+# require two safetensors file paths
+parser.add_argument(
+    'file_1',
+    nargs='+',
+    type=str,
+    help='Path to the first safetensors file'
+)
+
+parser.add_argument(
+    'file_2',
+    nargs='+',
+    type=str,
+    help='Path to the second safetensors file'
+)
+
+args = parser.parse_args()
+
+find_matches = False
+
+state_dict_file_1 = load_file(args.file_1[0])
+state_dict_1_keys = list(state_dict_file_1.keys())
+
+state_dict_file_2 = load_file(args.file_2[0])
+state_dict_2_keys = list(state_dict_file_2.keys())
+keys_in_both = []
+
+keys_not_in_state_dict_2 = []
+for key in state_dict_1_keys:
+    if key not in state_dict_2_keys:
+        keys_not_in_state_dict_2.append(key)
+
+keys_not_in_state_dict_1 = []
+for key in state_dict_2_keys:
+    if key not in state_dict_1_keys:
+        keys_not_in_state_dict_1.append(key)
+
+keys_in_both = []
+for key in state_dict_1_keys:
+    if key in state_dict_2_keys:
+        keys_in_both.append(key)
+
+# sort them
+keys_not_in_state_dict_2.sort()
+keys_not_in_state_dict_1.sort()
+keys_in_both.sort()
+
+
+json_data = {
+    "both": keys_in_both,
+    "state_dict_2": keys_not_in_state_dict_2,
+    "state_dict_1": keys_not_in_state_dict_1
+}
+json_data = json.dumps(json_data, indent=4)
+
+remaining_diffusers_values = OrderedDict()
+for key in keys_not_in_state_dict_1:
+    remaining_diffusers_values[key] = state_dict_file_2[key]
+
+# print(remaining_diffusers_values.keys())
+
+remaining_ldm_values = OrderedDict()
+for key in keys_not_in_state_dict_2:
+    remaining_ldm_values[key] = state_dict_file_1[key]
+
+# print(json_data)
+
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+json_save_path = os.path.join(project_root, 'config', 'keys.json')
+json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
+json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
+
+with open(json_save_path, 'w') as f:
+    f.write(json_data)
\ No newline at end of file
diff --git a/toolkit/lycoris_utils.py b/toolkit/lycoris_utils.py
index dad5aff8..af11ee9e 100644
--- a/toolkit/lycoris_utils.py
+++ b/toolkit/lycoris_utils.py
@@ -67,9 +67,6 @@ def extract_conv(
     return (extract_weight_A, extract_weight_B, diff), 'low rank'
 
 
-extra_weights = ['lora_unet_conv_in.alpha', 'lora_unet_conv_in.lora_down.weight', 'lora_unet_conv_in.lora_mid.weight', 'lora_unet_conv_in.lora_up.weight', 'lora_unet_conv_out.alpha', 'lora_unet_conv_out.lora_down.weight', 'lora_unet_conv_out.lora_mid.weight', 'lora_unet_conv_out.lora_up.weight', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.alpha',
'lora_unet_down_blocks_1_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_down_blocks_1_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_1_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_1_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.alpha', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_down_blocks_2_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_2_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_2_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.alpha', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_3_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.alpha', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_down_blocks_3_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_mid_block_resnets_0_time_emb_proj.alpha', 'lora_unet_mid_block_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_mid_block_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_mid_block_resnets_1_time_emb_proj.alpha', 'lora_unet_mid_block_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_mid_block_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_time_embedding_linear_1.alpha', 'lora_unet_time_embedding_linear_1.lora_down.weight', 'lora_unet_time_embedding_linear_1.lora_up.weight', 'lora_unet_time_embedding_linear_2.alpha', 'lora_unet_time_embedding_linear_2.lora_down.weight', 'lora_unet_time_embedding_linear_2.lora_up.weight', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_0_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_0_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_0_time_emb_proj.lora_down.weight', 
'lora_unet_up_blocks_1_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_1_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_1_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_2_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_2_resnets_2_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_0_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_0_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_1_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_1_time_emb_proj.lora_up.weight', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.alpha', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.lora_down.weight', 'lora_unet_up_blocks_3_resnets_2_conv_shortcut.lora_up.weight', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.alpha', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.lora_down.weight', 'lora_unet_up_blocks_3_resnets_2_time_emb_proj.lora_up.weight']
-
-
 def extract_linear(
     weight: Union[torch.Tensor, nn.Parameter],
     mode='fixed',
@@ -177,7 +174,7 @@ def extract_diff(
             if module.__class__.__name__ in target_replace_modules:
                 temp[name] = {}
                 for child_name, child_module in module.named_modules():
-                    if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
+                    if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
                         continue
                     temp[name][child_name] = child_module.weight
             elif name in target_replace_names:
@@ -190,12 +187,12 @@ def extract_diff(
                     lora_name = prefix + '.' + name + '.' + child_name
                     lora_name = lora_name.replace('.', '_')
                     layer = child_module.__class__.__name__
-                    if layer in {'Linear', 'Conv2d'}:
+                    if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
                         root_weight = child_module.weight
                         if torch.allclose(root_weight, weights[child_name]):
                             continue
 
-                    if layer == 'Linear':
+                    if layer == 'Linear' or layer == 'LoRACompatibleLinear':
                         weight, decompose_mode = extract_linear(
                             (child_module.weight - weights[child_name]),
                             mode,
@@ -204,7 +201,7 @@ def extract_diff(
                         )
                         if decompose_mode == 'low rank':
                             extract_a, extract_b, diff = weight
-                    elif layer == 'Conv2d':
+                    elif layer == 'Conv2d' or layer == 'LoRACompatibleConv':
                         is_linear = (child_module.weight.shape[2] == 1
                                      and child_module.weight.shape[3] == 1)
                         if not is_linear and linear_only:
@@ -258,12 +255,12 @@ def extract_diff(
                 lora_name = lora_name.replace('.', '_')
                 layer = module.__class__.__name__
 
-                if layer in {'Linear', 'Conv2d'}:
+                if layer in {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}:
                     root_weight = module.weight
                     if torch.allclose(root_weight, weights):
                         continue
 
-                if layer == 'Linear':
+                if layer == 'Linear' or layer == 'LoRACompatibleLinear':
                     weight, decompose_mode = extract_linear(
                         (root_weight - weights),
                         mode,
@@ -272,7 +269,7 @@ def extract_diff(
                     )
                     if decompose_mode == 'low rank':
                         extract_a, extract_b, diff = weight
-                elif layer == 'Conv2d':
+                elif layer == 'Conv2d' or layer == 'LoRACompatibleConv':
                     is_linear = (
                         root_weight.shape[2] == 1
                         and root_weight.shape[3] == 1
@@ -493,7 +490,8 @@ def merge(
         for name, module in tqdm(list(root_module.named_modules()), desc=f'Merging {prefix}'):
             if module.__class__.__name__ in target_replace_modules:
                 for child_name, child_module in module.named_modules():
-                    if child_module.__class__.__name__ not in {'Linear', 'Conv2d'}:
+                    if child_module.__class__.__name__ not in {'Linear', 'LoRACompatibleLinear', 'Conv2d',
+                                                               'LoRACompatibleConv'}:
                         continue
                     lora_name = prefix + '.' + name + '.' + child_name
                     lora_name = lora_name.replace('.', '_')
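Note on the class-name checks above: newer diffusers releases wrap the UNet's nn.Linear / nn.Conv2d layers in LoRACompatibleLinear / LoRACompatibleConv (subclasses of the plain torch modules), so any pass that filters modules by __class__.__name__ has to accept both spellings, which is what the widened sets in extract_diff() and merge() do. A minimal sketch of that filter, assuming only that torch is installed; the helper name iter_extractable_children is illustrative and not part of the patch:

    import torch.nn as nn

    # Accept both the plain torch classes and the diffusers LoRA-compatible wrappers.
    EXTRACTABLE = {'Linear', 'LoRACompatibleLinear', 'Conv2d', 'LoRACompatibleConv'}

    def iter_extractable_children(root: nn.Module):
        # Yield (name, module) pairs whose class name passes the same filter
        # used when extracting or merging LoRA weights.
        for name, module in root.named_modules():
            if module.__class__.__name__ in EXTRACTABLE:
                yield name, module

testing/compare_keys.py is a one-off helper: run it as, e.g., python testing/compare_keys.py first.safetensors second.safetensors, and it writes the shared and one-sided key lists to config/keys.json (the matched.json and duped.json paths are defined but not yet written to).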