Fixed some mismatched weights by adjusting tolerance. The mismatch ironically made the models better lol

This commit is contained in:
Jaret Burkett
2023-08-29 15:20:03 -06:00
parent 14ff51ceb4
commit 836fee47a6
8 changed files with 265 additions and 163 deletions

View File

@@ -93,7 +93,7 @@ total_keys = len(ldm_dict_keys)
matched_ldm_keys = []
matched_diffusers_keys = []
error_margin = 1e-4
error_margin = 1e-6
tmp_merge_key = "TMP___MERGE"

View File

@@ -3,6 +3,8 @@ import os
# add project root to sys path
import sys
from tqdm import tqdm
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
@@ -20,6 +22,8 @@ from toolkit.stable_diffusion_model import StableDiffusion
# you probably wont need this. Unless they change them.... again... again
# on second thought, you probably will
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
device = torch.device('cpu')
dtype = torch.float32
@@ -109,7 +113,26 @@ keys_in_both.sort()
if len(keys_not_in_state_dict_2) == 0 and len(keys_not_in_state_dict_1) == 0:
print("All keys match!")
exit(0)
print("Checking values...")
mismatch_keys = []
loss = torch.nn.MSELoss()
tolerance = 1e-6
for key in tqdm(keys_in_both):
if loss(state_dict_file_1[key], state_dict_file_2[key]) > tolerance:
print(f"Values for key {key} don't match!")
print(f"Loss: {loss(state_dict_file_1[key], state_dict_file_2[key])}")
mismatch_keys.append(key)
if len(mismatch_keys) == 0:
print("All values match!")
else:
print("Some valued font match!")
print(mismatch_keys)
mismatched_path = os.path.join(project_root, 'config', 'mismatch.json')
with open(mismatched_path, 'w') as f:
f.write(json.dumps(mismatch_keys, indent=4))
exit(0)
else:
print("Keys don't match!, generating info...")
@@ -132,17 +155,17 @@ for key in keys_not_in_state_dict_2:
# print(json_data)
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
json_save_path = os.path.join(project_root, 'config', 'keys.json')
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
state_dict_1_filename = os.path.basename(args.file_1[0])
state_dict_2_filename = os.path.basename(args.file_2[0])
# state_dict_2_filename = os.path.basename(args.file_2[0])
# save key names for each in own file
with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f:
f.write(json.dumps(state_dict_1_keys, indent=4))
with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f:
with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}_loop.json'), 'w') as f:
f.write(json.dumps(state_dict_2_keys, indent=4))
with open(json_save_path, 'w') as f: