Co-authored-by: graemeniedermayer <graemeniedermayer@users.noreply.github.com>
DenOfEquity
2025-02-27 17:54:44 +00:00
committed by GitHub
parent 8dd92501e6
commit f23bc80d2f
6 changed files with 1184 additions and 22 deletions


@@ -20,10 +20,11 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
 from backend.diffusion_engine.sd15 import StableDiffusion
 from backend.diffusion_engine.sd20 import StableDiffusion2
 from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
+from backend.diffusion_engine.sd35 import StableDiffusion3
 from backend.diffusion_engine.flux import Flux
-possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, Flux]
+possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Flux]
 logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -107,15 +108,18 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
         return model
-    if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
+    if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
         assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
         model_loader = None
         if cls_name == 'UNet2DConditionModel':
             model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
-        if cls_name == 'FluxTransformer2DModel':
+        elif cls_name == 'FluxTransformer2DModel':
             from backend.nn.flux import IntegratedFluxTransformer2DModel
             model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
+        elif cls_name == 'SD3Transformer2DModel':
+            from backend.nn.mmditx import MMDiTX
+            model_loader = lambda c: MMDiTX(**c)
         unet_config = guess.unet_config.copy()
         state_dict_parameters = memory_management.state_dict_parameters(state_dict)
@@ -246,10 +250,10 @@ def replace_state_dict(sd, asd, guess):
"-" : None,
"sd1" : None,
"sd2" : None,
"xlrf": "conditioner.embedders.0.model.",
"sdxl": "conditioner.embedders.1.model.",
"xlrf": "conditioner.embedders.0.model.transformer.",
"sdxl": "conditioner.embedders.1.model.transformer.",
"flux": None,
"sd3" : "text_encoders.clip_g.",
"sd3" : "text_encoders.clip_g.transformer.",
}
## prefixes used by various model types for CLIP-H
prefix_H = {
@@ -292,10 +296,10 @@ def replace_state_dict(sd, asd, guess):
     ## CLIP-G
     CLIP_G = { # key to identify source model : old_prefix
-        'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.',
-        'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.',
+        'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.transformer.',
+        'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.transformer.',
         'text_model.encoder.layers.0.layer_norm1.bias' : '',
-        'transformer.resblocks.0.ln_1.bias' : ''
+        'transformer.resblocks.0.ln_1.bias' : 'transformer.'
     }
     for CLIP_key in CLIP_G.keys():
         if CLIP_key in asd and asd[CLIP_key].shape[0] == 1280:
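
Note (not part of the commit): the CLIP_G table above acts as a set of probe keys. Each key reveals which layout the incoming state dict uses, and shape[0] == 1280 confirms the tensor belongs to CLIP-G rather than CLIP-L (width 768). A minimal standalone sketch of that detection idea, with illustrative dictionary contents rather than the repo's actual code:

# Standalone sketch (illustrative, not the repo's code) of probe-key detection:
# try known layer-norm keys and use the embedding width to tell CLIP-G (1280) from CLIP-L (768).
import torch

PROBES = {
    'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias': 'conditioner.embedders.1.model.transformer.',
    'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias': 'text_encoders.clip_g.transformer.',
    'transformer.resblocks.0.ln_1.bias': 'transformer.',
}

def detect_clip_g_prefix(asd):
    for probe, prefix in PROBES.items():
        if probe in asd and asd[probe].shape[0] == 1280:
            return prefix  # old prefix to strip when remapping keys
    return None

asd = {'transformer.resblocks.0.ln_1.bias': torch.zeros(1280)}
print(detect_clip_g_prefix(asd))  # transformer.
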
@@ -303,7 +307,7 @@ def replace_state_dict(sd, asd, guess):
             old_prefix = CLIP_G[CLIP_key]
             if new_prefix is not None:
-                if "resblocks" not in CLIP_key: # need to convert
+                if "resblocks" not in CLIP_key and model_type != "sd3": # need to convert
                     def convert_transformers(statedict, prefix_from, prefix_to, number):
                         keys_to_replace = {
                             "{}text_model.embeddings.position_embedding.weight" : "{}positional_embedding",
@@ -320,15 +324,15 @@ def replace_state_dict(sd, asd, guess):
"self_attn.out_proj" : "attn.out_proj" ,
}
for x in keys_to_replace:
for x in keys_to_replace: # remove trailing 'transformer.' from new prefix
k = x.format(prefix_from)
statedict[keys_to_replace[x].format(prefix_to)] = statedict.pop(k)
statedict[keys_to_replace[x].format(prefix_to[:-12])] = statedict.pop(k)
for resblock in range(number):
for y in ["weight", "bias"]:
for x in resblock_to_replace:
k = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, x, y)
k_to = "{}transformer.resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
k_to = "{}resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
statedict[k_to] = statedict.pop(k)
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.q_proj", y)
@@ -338,14 +342,16 @@ def replace_state_dict(sd, asd, guess):
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.v_proj", y)
weightsV = statedict.pop(k_from)
k_to = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
k_to = "{}resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
statedict[k_to] = torch.cat((weightsQ, weightsK, weightsV))
return statedict
asd = convert_transformers(asd, old_prefix, new_prefix, 32)
new_prefix = ""
for k, v in asd.items():
sd[k] = v
if old_prefix == "":
elif old_prefix == "":
for k, v in asd.items():
new_k = new_prefix + k
sd[new_k] = v
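
Note (not part of the commit): convert_transformers above packs the separate q/k/v projection weights of a transformers-format text encoder into the single fused in_proj tensor that the OpenCLIP layout expects, via torch.cat((weightsQ, weightsK, weightsV)). A minimal standalone sketch of that packing step, using hypothetical prefixes and a toy state dict:

# Standalone sketch (hypothetical prefixes, toy tensors) of packing transformers-style
# q/k/v projections into one OpenCLIP-style in_proj tensor, mirroring the
# torch.cat((weightsQ, weightsK, weightsV)) step in convert_transformers above.
import torch

def pack_qkv(statedict, prefix_from, prefix_to, number):
    for resblock in range(number):
        for y in ["weight", "bias"]:
            parts = []
            for proj in ["q_proj", "k_proj", "v_proj"]:
                k = "{}text_model.encoder.layers.{}.self_attn.{}.{}".format(prefix_from, resblock, proj, y)
                parts.append(statedict.pop(k))
            # OpenCLIP keeps q, k and v stacked along dim 0 in a single tensor
            statedict["{}resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)] = torch.cat(parts)
    return statedict

# toy example: 1 layer, hidden size 4
sd = {}
for proj in ["q_proj", "k_proj", "v_proj"]:
    sd["te.text_model.encoder.layers.0.self_attn.{}.weight".format(proj)] = torch.randn(4, 4)
    sd["te.text_model.encoder.layers.0.self_attn.{}.bias".format(proj)] = torch.randn(4)

sd = pack_qkv(sd, "te.", "clip_g.transformer.", 1)
print(sd["clip_g.transformer.resblocks.0.attn.in_proj_weight"].shape)  # torch.Size([12, 4])
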
@@ -360,7 +366,7 @@ def replace_state_dict(sd, asd, guess):
         'conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'conditioner.embedders.0.transformer.',
         'text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_l.transformer.',
         'text_model.encoder.layers.0.layer_norm1.bias' : '',
-        'transformer.resblocks.0.ln_1.bias' : ''
+        'transformer.resblocks.0.ln_1.bias' : 'transformer.'
     }
     for CLIP_key in CLIP_L.keys():
@@ -376,6 +382,7 @@ def replace_state_dict(sd, asd, guess):
"token_embedding.weight": "{}text_model.embeddings.token_embedding.weight",
"ln_final.weight" : "{}text_model.final_layer_norm.weight",
"ln_final.bias" : "{}text_model.final_layer_norm.bias",
"text_projection" : "text_projection.weight",
}
resblock_to_replace = {
"ln_1" : "layer_norm1",
@@ -391,11 +398,11 @@ def replace_state_dict(sd, asd, guess):
                         for resblock in range(number):
                             for y in ["weight", "bias"]:
                                 for x in resblock_to_replace:
-                                    k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
+                                    k = "{}resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                                     k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                                     statedict[k_to] = statedict.pop(k)
-                                k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
+                                k_from = "{}resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
                                 weights = statedict.pop(k_from)
                                 shape_from = weights.shape[0] // 3
                                 for x in range(3):
@@ -405,9 +412,10 @@ def replace_state_dict(sd, asd, guess):
                         return statedict
                     asd = transformers_convert(asd, old_prefix, new_prefix, 12)
+                    new_prefix = ""
                     for k, v in asd.items():
                         sd[k] = v
-                if old_prefix == "":
+                elif old_prefix == "":
                     for k, v in asd.items():
                         new_k = new_prefix + k
                         sd[new_k] = v
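
Note (not part of the commit): transformers_convert above goes the opposite way, splitting the fused in_proj tensor back into q/k/v projections with shape_from = weights.shape[0] // 3. A minimal standalone sketch of that split, again with hypothetical prefixes and a toy state dict:

# Standalone sketch (hypothetical prefixes, toy tensors) of splitting an OpenCLIP-style
# fused in_proj tensor back into transformers-style q/k/v projections, mirroring the
# shape_from = weights.shape[0] // 3 split in transformers_convert above.
import torch

def split_in_proj(statedict, prefix_from, prefix_to, number):
    names = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]
    for resblock in range(number):
        for y in ["weight", "bias"]:
            weights = statedict.pop("{}resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y))
            shape_from = weights.shape[0] // 3  # q, k, v are stacked along dim 0
            for x in range(3):
                k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, names[x], y)
                statedict[k_to] = weights[x * shape_from : (x + 1) * shape_from]
    return statedict

# toy example: 1 layer, hidden size 4, so in_proj_weight is (12, 4)
sd = {"clip_l.resblocks.0.attn.in_proj_weight": torch.randn(12, 4),
      "clip_l.resblocks.0.attn.in_proj_bias": torch.randn(12)}
sd = split_in_proj(sd, "clip_l.", "te.", 1)
print(sd["te.text_model.encoder.layers.0.self_attn.q_proj.weight"].shape)  # torch.Size([4, 4])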