mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 06:29:48 +00:00
Fixed some mismatched weights by adjusting tolerance. The mismatch ironically made the models better lol
This commit is contained in:
@@ -3,6 +3,8 @@ import os
|
||||
# add project root to sys path
|
||||
import sys
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import torch
|
||||
@@ -20,6 +22,8 @@ from toolkit.stable_diffusion_model import StableDiffusion
|
||||
# you probably won't need this. Unless they change them.... again... again
|
||||
# on second thought, you probably will
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
device = torch.device('cpu')
|
||||
dtype = torch.float32
|
||||
|
||||
@@ -109,7 +113,26 @@ keys_in_both.sort()
|
||||
|
||||
if len(keys_not_in_state_dict_2) == 0 and len(keys_not_in_state_dict_1) == 0:
|
||||
print("All keys match!")
|
||||
exit(0)
|
||||
print("Checking values...")
|
||||
mismatch_keys = []
|
||||
loss = torch.nn.MSELoss()
|
||||
tolerance = 1e-6
|
||||
for key in tqdm(keys_in_both):
|
||||
if loss(state_dict_file_1[key], state_dict_file_2[key]) > tolerance:
|
||||
print(f"Values for key {key} don't match!")
|
||||
print(f"Loss: {loss(state_dict_file_1[key], state_dict_file_2[key])}")
|
||||
mismatch_keys.append(key)
|
||||
|
||||
if len(mismatch_keys) == 0:
|
||||
print("All values match!")
|
||||
else:
|
||||
print("Some valued font match!")
|
||||
print(mismatch_keys)
|
||||
mismatched_path = os.path.join(project_root, 'config', 'mismatch.json')
|
||||
with open(mismatched_path, 'w') as f:
|
||||
f.write(json.dumps(mismatch_keys, indent=4))
|
||||
exit(0)
|
||||
|
||||
else:
|
||||
print("Keys don't match!, generating info...")
|
||||
|
||||
@@ -132,17 +155,17 @@ for key in keys_not_in_state_dict_2:
|
||||
|
||||
# print(json_data)
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
json_save_path = os.path.join(project_root, 'config', 'keys.json')
|
||||
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
|
||||
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')
|
||||
state_dict_1_filename = os.path.basename(args.file_1[0])
|
||||
state_dict_2_filename = os.path.basename(args.file_2[0])
|
||||
# state_dict_2_filename = os.path.basename(args.file_2[0])
|
||||
# save key names for each in own file
|
||||
with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}.json'), 'w') as f:
|
||||
f.write(json.dumps(state_dict_1_keys, indent=4))
|
||||
|
||||
with open(os.path.join(project_root, 'config', f'{state_dict_2_filename}.json'), 'w') as f:
|
||||
with open(os.path.join(project_root, 'config', f'{state_dict_1_filename}_loop.json'), 'w') as f:
|
||||
f.write(json.dumps(state_dict_2_keys, indent=4))
|
||||
|
||||
with open(json_save_path, 'w') as f:
|
||||
|
||||
Reference in New Issue
Block a user