mirror of https://github.com/ostris/ai-toolkit.git
Handle conversions back to ldm for saving
testing/test_vae_cycle.py | 112 lines (new normal file)
@@ -0,0 +1,112 @@
import os
import json
from collections import OrderedDict

import torch
from safetensors.torch import load_file

from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm, vae_keys_squished_on_diffusers

# this was just used to match the vae keys to the diffusers keys
# you probably won't need this. Unless they change them.... again... again
# on second thought, you probably will
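
# A note on vae_keys_squished_on_diffusers (inferred from the matcher below,
# so treat it as a reading of the code rather than a spec): diffusers stores
# the VAE attention projections as linear weights of shape [out, in], while
# the ldm checkpoint keeps them as 1x1 convs of shape [out, in, 1, 1], so
# those tensors only compare equal after adding two trailing dims, e.g.:
#   linear_w = torch.randn(512, 512)                 # diffusers-style weight
#   conv_w = linear_w.unsqueeze(-1).unsqueeze(-1)    # ldm-style 1x1 conv weight
#   assert conv_w.shape == (512, 512, 1, 1)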

device = torch.device('cpu')
dtype = torch.float32
vae_path = '/mnt/Models/stable-diffusion/models/VAE/vae-ft-mse-840000-ema-pruned/vae-ft-mse-840000-ema-pruned.safetensors'

# set True to brute-force match ldm tensors to diffusers tensors by value
find_matches = False

state_dict_ldm = load_file(vae_path)
diffusers_vae = load_vae(vae_path, dtype=dtype).to(device)

ldm_keys = state_dict_ldm.keys()

matched_keys = {}
duplicated_keys = {}

if find_matches:
    # find values that match with a very low mse
    diffusers_state_dict = diffusers_vae.state_dict()
    for ldm_key in ldm_keys:
        ldm_value = state_dict_ldm[ldm_key]
        for diffusers_key in list(diffusers_state_dict.keys()):
            diffusers_value = diffusers_state_dict[diffusers_key]
            if diffusers_key in vae_keys_squished_on_diffusers:
                # add the trailing dims so linear weights compare against 1x1 convs
                diffusers_value = diffusers_value.clone().unsqueeze(-1).unsqueeze(-1)
            # if they are not the same shape, skip
            if ldm_value.shape != diffusers_value.shape:
                continue
            mse = torch.nn.functional.mse_loss(ldm_value, diffusers_value)
            if mse < 1e-6:
                if ldm_key in matched_keys:
                    print(f'{ldm_key} already matched to {matched_keys[ldm_key]}')
                    if ldm_key in duplicated_keys:
                        duplicated_keys[ldm_key].append(diffusers_key)
                    else:
                        duplicated_keys[ldm_key] = [diffusers_key]
                    continue
                # keep scanning instead of breaking here so later value
                # matches land in duplicated_keys for review
                matched_keys[ldm_key] = diffusers_key

    print(f'Found {len(matched_keys)} matches')
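
# Value matching is ambiguous when tensors are numerically identical (e.g.
# same-shaped biases), so one ldm key can match several diffusers keys;
# duplicated_keys collects those collisions for manual review in duped.json
# rather than guessing which mapping is right.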

dif_to_ldm_state_dict = convert_diffusers_back_to_ldm(diffusers_vae)
dif_to_ldm_state_dict_keys = list(dif_to_ldm_state_dict.keys())

keys_not_in_diffusers = []
for key in ldm_keys:
    if key not in dif_to_ldm_state_dict_keys:
        keys_not_in_diffusers.append(key)

keys_not_in_ldm = []
for key in dif_to_ldm_state_dict_keys:
    if key not in ldm_keys:
        keys_not_in_ldm.append(key)

keys_in_both = []
for key in ldm_keys:
    if key in dif_to_ldm_state_dict_keys:
        keys_in_both.append(key)

# sort them so the output is stable and diff-able between runs
keys_not_in_diffusers.sort()
keys_not_in_ldm.sort()
keys_in_both.sort()
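
# This is the cycle test proper: the round trip ldm -> diffusers -> ldm is
# key-complete when both difference lists are empty, i.e. (an equivalent
# set-based sketch of the same check):
#   assert set(ldm_keys) == set(dif_to_ldm_state_dict_keys)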

# print(f'Keys in LDM but not in Diffusers: {len(keys_not_in_diffusers)}{keys_not_in_diffusers}')
# print(f'Keys in Diffusers but not in LDM: {len(keys_not_in_ldm)}{keys_not_in_ldm}')
# print(f'Keys in both: {len(keys_in_both)}{keys_in_both}')

json_data = {
    "both": keys_in_both,
    "ldm": keys_not_in_diffusers,
    "diffusers": keys_not_in_ldm
}
json_str = json.dumps(json_data, indent=4)
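
# keys.json then has this shape (key names illustrative, lists abridged):
# {
#     "both": ["decoder.conv_in.bias", "decoder.conv_in.weight", "..."],
#     "ldm": [],
#     "diffusers": []
# }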

remaining_diffusers_values = OrderedDict()
for key in keys_not_in_ldm:
    remaining_diffusers_values[key] = dif_to_ldm_state_dict[key]

# print(remaining_diffusers_values.keys())

remaining_ldm_values = OrderedDict()
for key in keys_not_in_diffusers:
    remaining_ldm_values[key] = state_dict_ldm[key]

# print(json_str)

project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
json_save_path = os.path.join(project_root, 'config', 'keys.json')
json_matched_save_path = os.path.join(project_root, 'config', 'matched.json')
json_duped_save_path = os.path.join(project_root, 'config', 'duped.json')

with open(json_save_path, 'w') as f:
    f.write(json_str)

if find_matches:
    with open(json_matched_save_path, 'w') as f:
        f.write(json.dumps(matched_keys, indent=4))
    with open(json_duped_save_path, 'w') as f:
        f.write(json.dumps(duplicated_keys, indent=4))
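
# Usage sketch (run from the repo root; vae_path above must point at a local
# copy of the checkpoint):
#   python testing/test_vae_cycle.py
# After a diffusers upgrade, re-run and diff config/keys.json against the
# previous output; new entries under "ldm" or "diffusers" mean the conversion
# code in toolkit.kohya_model_util needs updating again.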