improve vae key mapping

This commit is contained in:
layerdiffusion
2024-07-30 09:23:58 -06:00
parent 3289ccb53f
commit 40dd61ba6c
4 changed files with 23 additions and 11 deletions

View File

@@ -24,16 +24,20 @@ def split_state_dict_with_prefix(sd, prefix):
return vae_sd
def shrink_last_key(t):
ts = t.split('.')
del ts[-1]
return '.'.join(ts)
def compile_state_dict(state_dict):
sd = {}
mapping = {}
for k, v in state_dict.items():
sd[k] = v.value
mapping[shrink_last_key(v.key)] = shrink_last_key(k)
mapping[v.key] = (k, v.advanced_indexing)
return sd, mapping
def map_state_dict(sd, mapping):
new_sd = {}
for k, v in sd.items():
k, indexing = mapping.get(k, (k, None))
if indexing is not None:
v = v[indexing]
new_sd[k] = v
return new_sd