diff --git a/testing/compare_keys.py b/testing/compare_keys.py index 021178b..bf4f952 100644 --- a/testing/compare_keys.py +++ b/testing/compare_keys.py @@ -2,6 +2,7 @@ import argparse import os import torch +from diffusers.loaders import LoraLoaderMixin from safetensors.torch import load_file from collections import OrderedDict import json @@ -63,8 +64,8 @@ keys_in_both.sort() json_data = { "both": keys_in_both, - "state_dict_2": keys_not_in_state_dict_2, - "state_dict_1": keys_not_in_state_dict_1 + "not_in_state_dict_2": keys_not_in_state_dict_2, + "not_in_state_dict_1": keys_not_in_state_dict_1 } json_data = json.dumps(json_data, indent=4) @@ -84,6 +85,15 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) json_save_path = os.path.join(project_root, 'config', 'keys.json') json_matched_save_path = os.path.join(project_root, 'config', 'matched.json') json_duped_save_path = os.path.join(project_root, 'config', 'duped.json') +state_dict_1_filename = os.path.basename(args.file_1[0]) +state_dict_2_filename = os.path.basename(args.file_2[0]) +# save key names for each in own file +with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f: + f.write(json.dumps(state_dict_1_keys, indent=4)) + +with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f: + f.write(json.dumps(state_dict_2_keys, indent=4)) + with open(json_save_path, 'w') as f: f.write(json_data) \ No newline at end of file diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index a41d354..01533d4 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -607,6 +607,24 @@ class StableDiffusion: return embedding_list, latent_list + def get_weight_by_name(self, name): + # weights begin with te{te_num}_ for text encoder + # weights begin with unet_ for unet_ + if name.startswith('te'): + key = name[4:] + # text encoder + te_num = int(name[2]) + if isinstance(self.text_encoder, list): + return self.text_encoder[te_num].state_dict()[key] + else: + return self.text_encoder.state_dict()[key] + elif name.startswith('unet'): + key = name[5:] + # unet + return self.unet.state_dict()[key] + + raise ValueError(f"Unknown weight name: {name}") + def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): state_dict = {}