mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 03:01:15 +00:00
increased support for custom CLIPs (#2642)
increased support for custom CLIPs more forms recognised now can be applied to sd1.5, sdxl, (sd3)
This commit is contained in:
@@ -209,12 +209,208 @@ def replace_state_dict(sd, asd, guess):
|
|||||||
for k, v in asd.items():
|
for k, v in asd.items():
|
||||||
sd[vae_key_prefix + k] = v
|
sd[vae_key_prefix + k] = v
|
||||||
|
|
||||||
if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
|
|
||||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
|
## identify model type
|
||||||
for k in keys_to_delete:
|
flux_test_key = "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale"
|
||||||
del sd[k]
|
sd3_test_key = "model.diffusion_model.final_layer.adaLN_modulation.1.bias"
|
||||||
for k, v in asd.items():
|
legacy_test_key = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
|
|
||||||
|
model_type = "-"
|
||||||
|
if legacy_test_key in sd:
|
||||||
|
match sd[legacy_test_key].shape[1]:
|
||||||
|
case 768:
|
||||||
|
model_type = "sd1"
|
||||||
|
case 1024:
|
||||||
|
model_type = "sd2"
|
||||||
|
case 2048:
|
||||||
|
model_type = "sdxl"
|
||||||
|
elif flux_test_key in sd:
|
||||||
|
model_type = "flux"
|
||||||
|
elif sd3_test_key in sd:
|
||||||
|
model_type = "sd3"
|
||||||
|
|
||||||
|
## prefixes used by various model types for CLIP-L
|
||||||
|
prefix_L = {
|
||||||
|
"-" : None,
|
||||||
|
"sd1" : "cond_stage_model.transformer.",
|
||||||
|
"sd2" : None,
|
||||||
|
"sdxl": "conditioner.embedders.0.transformer.",
|
||||||
|
"flux": "text_encoders.clip_l.transformer.",
|
||||||
|
"sd3" : "text_encoders.clip_l.transformer.",
|
||||||
|
}
|
||||||
|
## prefixes used by various model types for CLIP-G
|
||||||
|
prefix_G = {
|
||||||
|
"-" : None,
|
||||||
|
"sd1" : None,
|
||||||
|
"sd2" : None,
|
||||||
|
"sdxl": "conditioner.embedders.1.model.",
|
||||||
|
"flux": None,
|
||||||
|
"sd3" : "text_encoders.clip_g.",
|
||||||
|
}
|
||||||
|
## prefixes used by various model types for CLIP-H
|
||||||
|
prefix_H = {
|
||||||
|
"-" : None,
|
||||||
|
"sd1" : None,
|
||||||
|
"sd2" : "conditioner.embedders.0.model.",
|
||||||
|
"sdxl": None,
|
||||||
|
"flux": None,
|
||||||
|
"sd3" : None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
|
||||||
|
if "first_stage_model.decoder.conv_in.weight" in asd:
|
||||||
|
channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1]
|
||||||
|
if model_type == "sd1" or model_type == "sd2" or model_type == "sdxl":
|
||||||
|
if channels == 4:
|
||||||
|
for k, v in asd.items():
|
||||||
|
sd[k] = v
|
||||||
|
elif model_type == "sd3":
|
||||||
|
if channels == 16:
|
||||||
|
for k, v in asd.items():
|
||||||
|
sd[k] = v
|
||||||
|
|
||||||
|
## CLIP-H
|
||||||
|
CLIP_H = { # key to identify source model old_prefix
|
||||||
|
'cond_stage_model.model.ln_final.weight' : 'cond_stage_model.model.',
|
||||||
|
# 'text_model.encoder.layers.0.layer_norm1.bias' : 'text_model'. # would need converting
|
||||||
|
}
|
||||||
|
for CLIP_key in CLIP_H.keys():
|
||||||
|
if CLIP_key in asd and asd[CLIP_key].shape[0] == 1024:
|
||||||
|
new_prefix = prefix_H[model_type]
|
||||||
|
old_prefix = CLIP_H[CLIP_key]
|
||||||
|
|
||||||
|
if new_prefix is not None:
|
||||||
|
for k, v in asd.items():
|
||||||
|
new_k = k.replace(old_prefix, new_prefix)
|
||||||
|
sd[new_k] = v
|
||||||
|
|
||||||
|
## CLIP-G
|
||||||
|
CLIP_G = { # key to identify source model old_prefix
|
||||||
|
'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.',
|
||||||
|
'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.',
|
||||||
|
'text_model.encoder.layers.0.layer_norm1.bias' : '',
|
||||||
|
'transformer.resblocks.0.ln_1.bias' : ''
|
||||||
|
}
|
||||||
|
for CLIP_key in CLIP_G.keys():
|
||||||
|
if CLIP_key in asd and asd[CLIP_key].shape[0] == 1280:
|
||||||
|
new_prefix = prefix_G[model_type]
|
||||||
|
old_prefix = CLIP_G[CLIP_key]
|
||||||
|
|
||||||
|
if new_prefix is not None:
|
||||||
|
if "resblocks" not in CLIP_key: # need to convert
|
||||||
|
def convert_transformers(statedict, prefix_from, prefix_to, number):
|
||||||
|
keys_to_replace = {
|
||||||
|
"{}text_model.embeddings.position_embedding.weight" : "{}positional_embedding",
|
||||||
|
"{}text_model.embeddings.token_embedding.weight" : "{}token_embedding.weight",
|
||||||
|
"{}text_model.final_layer_norm.weight" : "{}ln_final.weight",
|
||||||
|
"{}text_model.final_layer_norm.bias" : "{}ln_final.bias",
|
||||||
|
"text_projection.weight" : "{}text_projection",
|
||||||
|
}
|
||||||
|
resblock_to_replace = {
|
||||||
|
"layer_norm1" : "ln_1",
|
||||||
|
"layer_norm2" : "ln_2",
|
||||||
|
"mlp.fc1" : "mlp.c_fc",
|
||||||
|
"mlp.fc2" : "mlp.c_proj",
|
||||||
|
"self_attn.out_proj" : "attn.out_proj" ,
|
||||||
|
}
|
||||||
|
|
||||||
|
for x in keys_to_replace:
|
||||||
|
k = x.format(prefix_from)
|
||||||
|
statedict[keys_to_replace[x].format(prefix_to)] = statedict.pop(k)
|
||||||
|
|
||||||
|
for resblock in range(number):
|
||||||
|
for y in ["weight", "bias"]:
|
||||||
|
for x in resblock_to_replace:
|
||||||
|
k = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||||
|
k_to = "{}transformer.resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||||
|
statedict[k_to] = statedict.pop(k)
|
||||||
|
|
||||||
|
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.q_proj", y)
|
||||||
|
weightsQ = statedict.pop(k_from)
|
||||||
|
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.k_proj", y)
|
||||||
|
weightsK = statedict.pop(k_from)
|
||||||
|
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.v_proj", y)
|
||||||
|
weightsV = statedict.pop(k_from)
|
||||||
|
|
||||||
|
k_to = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
|
||||||
|
statedict[k_to] = torch.cat((weightsQ, weightsK, weightsV))
|
||||||
|
return statedict
|
||||||
|
|
||||||
|
asd = convert_transformers(asd, old_prefix, new_prefix, 32)
|
||||||
|
new_prefix = ""
|
||||||
|
|
||||||
|
if old_prefix == "":
|
||||||
|
for k, v in asd.items():
|
||||||
|
new_k = new_prefix + k
|
||||||
|
sd[new_k] = v
|
||||||
|
else:
|
||||||
|
for k, v in asd.items():
|
||||||
|
new_k = k.replace(old_prefix, new_prefix)
|
||||||
|
sd[new_k] = v
|
||||||
|
|
||||||
|
## CLIP-L
|
||||||
|
CLIP_L = { # key to identify source model old_prefix
|
||||||
|
'cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'cond_stage_model.transformer.',
|
||||||
|
'conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'conditioner.embedders.0.transformer.',
|
||||||
|
'text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_l.transformer.',
|
||||||
|
'text_model.encoder.layers.0.layer_norm1.bias' : '',
|
||||||
|
'transformer.resblocks.0.ln_1.bias' : ''
|
||||||
|
}
|
||||||
|
|
||||||
|
for CLIP_key in CLIP_L.keys():
|
||||||
|
if CLIP_key in asd and asd[CLIP_key].shape[0] == 768:
|
||||||
|
new_prefix = prefix_L[model_type]
|
||||||
|
old_prefix = CLIP_L[CLIP_key]
|
||||||
|
|
||||||
|
if new_prefix is not None:
|
||||||
|
if "resblocks" in CLIP_key: # need to convert
|
||||||
|
def transformers_convert(statedict, prefix_from, prefix_to, number):
|
||||||
|
keys_to_replace = {
|
||||||
|
"positional_embedding" : "{}text_model.embeddings.position_embedding.weight",
|
||||||
|
"token_embedding.weight": "{}text_model.embeddings.token_embedding.weight",
|
||||||
|
"ln_final.weight" : "{}text_model.final_layer_norm.weight",
|
||||||
|
"ln_final.bias" : "{}text_model.final_layer_norm.bias",
|
||||||
|
}
|
||||||
|
resblock_to_replace = {
|
||||||
|
"ln_1" : "layer_norm1",
|
||||||
|
"ln_2" : "layer_norm2",
|
||||||
|
"mlp.c_fc" : "mlp.fc1",
|
||||||
|
"mlp.c_proj" : "mlp.fc2",
|
||||||
|
"attn.out_proj" : "self_attn.out_proj",
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in keys_to_replace:
|
||||||
|
statedict[keys_to_replace[k].format(prefix_to)] = statedict.pop(k)
|
||||||
|
|
||||||
|
for resblock in range(number):
|
||||||
|
for y in ["weight", "bias"]:
|
||||||
|
for x in resblock_to_replace:
|
||||||
|
k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||||
|
k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||||
|
statedict[k_to] = statedict.pop(k)
|
||||||
|
|
||||||
|
k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
|
||||||
|
weights = statedict.pop(k_from)
|
||||||
|
shape_from = weights.shape[0] // 3
|
||||||
|
for x in range(3):
|
||||||
|
p = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
|
||||||
|
k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, p[x], y)
|
||||||
|
statedict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||||
|
return statedict
|
||||||
|
|
||||||
|
asd = transformers_convert(asd, old_prefix, new_prefix, 12)
|
||||||
|
new_prefix = ""
|
||||||
|
|
||||||
|
if old_prefix == "":
|
||||||
|
for k, v in asd.items():
|
||||||
|
new_k = new_prefix + k
|
||||||
|
sd[new_k] = v
|
||||||
|
else:
|
||||||
|
for k, v in asd.items():
|
||||||
|
new_k = k.replace(old_prefix, new_prefix)
|
||||||
|
sd[new_k] = v
|
||||||
|
|
||||||
|
|
||||||
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
|
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
|
||||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
|
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
|
||||||
@@ -227,9 +423,8 @@ def replace_state_dict(sd, asd, guess):
|
|||||||
|
|
||||||
|
|
||||||
def preprocess_state_dict(sd):
|
def preprocess_state_dict(sd):
|
||||||
if any("double_block" in k for k in sd.keys()):
|
if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
|
||||||
if not any(k.startswith("model.diffusion_model") for k in sd.keys()):
|
sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
|
||||||
sd = {f"model.diffusion_model.{k}": v for k, v in sd.items()}
|
|
||||||
|
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
@@ -243,11 +438,14 @@ def split_state_dict(sd, additional_state_dicts: list = None):
|
|||||||
for asd in additional_state_dicts:
|
for asd in additional_state_dicts:
|
||||||
asd = load_torch_file(asd)
|
asd = load_torch_file(asd)
|
||||||
sd = replace_state_dict(sd, asd, guess)
|
sd = replace_state_dict(sd, asd, guess)
|
||||||
|
del asd
|
||||||
|
|
||||||
guess.clip_target = guess.clip_target(sd)
|
guess.clip_target = guess.clip_target(sd)
|
||||||
guess.model_type = guess.model_type(sd)
|
guess.model_type = guess.model_type(sd)
|
||||||
guess.ztsnr = 'ztsnr' in sd
|
guess.ztsnr = 'ztsnr' in sd
|
||||||
|
|
||||||
|
sd = guess.process_vae_state_dict(sd)
|
||||||
|
|
||||||
state_dict = {
|
state_dict = {
|
||||||
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
|
guess.unet_target: try_filter_state_dict(sd, guess.unet_key_prefix),
|
||||||
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
|
guess.vae_target: try_filter_state_dict(sd, guess.vae_key_prefix)
|
||||||
|
|||||||
Reference in New Issue
Block a user