mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added method to get specific keys from model
This commit is contained in:
@@ -2,6 +2,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers.loaders import LoraLoaderMixin
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import json
|
import json
|
||||||
@@ -63,8 +64,8 @@ keys_in_both.sort()
|
|||||||
|
|
||||||
json_data = {
|
json_data = {
|
||||||
"both": keys_in_both,
|
"both": keys_in_both,
|
||||||
"state_dict_2": keys_not_in_state_dict_2,
|
"not_in_state_dict_2": keys_not_in_state_dict_2,
|
||||||
"state_dict_1": keys_not_in_state_dict_1
|
"not_in_state_dict_1": keys_not_in_state_dict_1
|
||||||
}
|
}
|
||||||
json_data = json.dumps(json_data, indent=4)
|
json_data = json.dumps(json_data, indent=4)
|
||||||
|
|
||||||
@@ -84,6 +85,15 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|||||||
json_save_path = os.path.join(project_root, 'config', 'keys.json')
|
json_save_path = os.path.join(project_root, 'config', 'keys.json')
|
||||||
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
|
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
|
||||||
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
|
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
|
||||||
|
state_dict_1_filename = os.path.basename(args.file_1[0])
|
||||||
|
state_dict_2_filename = os.path.basename(args.file_2[0])
|
||||||
|
# save key names for each in own file
|
||||||
|
with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f:
|
||||||
|
f.write(json.dumps(state_dict_1_keys, indent=4))
|
||||||
|
|
||||||
|
with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f:
|
||||||
|
f.write(json.dumps(state_dict_2_keys, indent=4))
|
||||||
|
|
||||||
|
|
||||||
with open(json_save_path, 'w') as f:
|
with open(json_save_path, 'w') as f:
|
||||||
f.write(json_data)
|
f.write(json_data)
|
||||||
@@ -607,6 +607,24 @@ class StableDiffusion:
|
|||||||
|
|
||||||
return embedding_list, latent_list
|
return embedding_list, latent_list
|
||||||
|
|
||||||
|
def get_weight_by_name(self, name):
|
||||||
|
# weights begin with te{te_num}_ for text encoder
|
||||||
|
# weights begin with unet_ for unet_
|
||||||
|
if name.startswith('te'):
|
||||||
|
key = name[4:]
|
||||||
|
# text encoder
|
||||||
|
te_num = int(name[2])
|
||||||
|
if isinstance(self.text_encoder, list):
|
||||||
|
return self.text_encoder[te_num].state_dict()[key]
|
||||||
|
else:
|
||||||
|
return self.text_encoder.state_dict()[key]
|
||||||
|
elif name.startswith('unet'):
|
||||||
|
key = name[5:]
|
||||||
|
# unet
|
||||||
|
return self.unet.state_dict()[key]
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown weight name: {name}")
|
||||||
|
|
||||||
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user