SD3+ (#2688)
Co-authored-by: graemeniedermayer <graemeniedermayer@users.noreply.github.com>
@@ -20,10 +20,11 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
 from backend.diffusion_engine.sd15 import StableDiffusion
 from backend.diffusion_engine.sd20 import StableDiffusion2
 from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
+from backend.diffusion_engine.sd35 import StableDiffusion3
 from backend.diffusion_engine.flux import Flux


-possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, Flux]
+possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Flux]


 logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -107,15 +108,18 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
         load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])

         return model

-    if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
+    if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
         assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'

         model_loader = None
         if cls_name == 'UNet2DConditionModel':
             model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
-        if cls_name == 'FluxTransformer2DModel':
+        elif cls_name == 'FluxTransformer2DModel':
             from backend.nn.flux import IntegratedFluxTransformer2DModel
             model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
+        elif cls_name == 'SD3Transformer2DModel':
+            from backend.nn.mmditx import MMDiTX
+            model_loader = lambda c: MMDiTX(**c)

         unet_config = guess.unet_config.copy()
         state_dict_parameters = memory_management.state_dict_parameters(state_dict)
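
(Aside, not part of the diff: the branch above just picks a constructor for the bare state-dict path. A self-contained sketch of that dispatch; the helper name and the trailing usage line are hypothetical.)

    def pick_model_loader(cls_name):
        # mirrors the branch above: map a diffusers class name to forge's integrated implementation
        if cls_name == 'UNet2DConditionModel':
            from backend.nn.unet import IntegratedUNet2DConditionModel
            return lambda c: IntegratedUNet2DConditionModel.from_config(c)
        elif cls_name == 'FluxTransformer2DModel':
            from backend.nn.flux import IntegratedFluxTransformer2DModel
            return lambda c: IntegratedFluxTransformer2DModel(**c)
        elif cls_name == 'SD3Transformer2DModel':
            from backend.nn.mmditx import MMDiTX
            return lambda c: MMDiTX(**c)
        return None

    # hypothetical use: model = pick_model_loader(cls_name)(unet_config)
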
@@ -246,10 +250,10 @@ def replace_state_dict(sd, asd, guess):
         "-"   : None,
         "sd1" : None,
         "sd2" : None,
-        "xlrf": "conditioner.embedders.0.model.",
-        "sdxl": "conditioner.embedders.1.model.",
+        "xlrf": "conditioner.embedders.0.model.transformer.",
+        "sdxl": "conditioner.embedders.1.model.transformer.",
         "flux": None,
-        "sd3" : "text_encoders.clip_g.",
+        "sd3" : "text_encoders.clip_g.transformer.",
     }
     ## prefixes used by various model types for CLIP-H
     prefix_H = {
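
(Aside, not part of the diff: these tables only store target key prefixes; the surrounding replace_state_dict logic amounts to re-rooting an auxiliary text-encoder state dict under the prefix the target checkpoint layout expects. A minimal sketch with a hypothetical helper name:)

    def reprefix(asd, old_prefix, new_prefix):
        # hypothetical helper: move every key from old_prefix to new_prefix
        out = {}
        for k, v in asd.items():
            if k.startswith(old_prefix):
                out[new_prefix + k[len(old_prefix):]] = v
            else:
                out[k] = v
        return out

    # e.g. 'text_model.encoder.layers.0.layer_norm1.bias' lands, for an SD3 checkpoint, under
    # 'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias'
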
@@ -292,10 +296,10 @@ def replace_state_dict(sd, asd, guess):

     ## CLIP-G
     CLIP_G = { # key to identify source model    old_prefix
-        'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.',
-        'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.',
+        'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.transformer.',
+        'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.transformer.',
         'text_model.encoder.layers.0.layer_norm1.bias' : '',
-        'transformer.resblocks.0.ln_1.bias' : ''
+        'transformer.resblocks.0.ln_1.bias' : 'transformer.'
     }
     for CLIP_key in CLIP_G.keys():
         if CLIP_key in asd and asd[CLIP_key].shape[0] == 1280:
@@ -303,7 +307,7 @@ def replace_state_dict(sd, asd, guess):
             old_prefix = CLIP_G[CLIP_key]

             if new_prefix is not None:
-                if "resblocks" not in CLIP_key: # need to convert
+                if "resblocks" not in CLIP_key and model_type != "sd3": # need to convert
                     def convert_transformers(statedict, prefix_from, prefix_to, number):
                         keys_to_replace = {
                             "{}text_model.embeddings.position_embedding.weight" : "{}positional_embedding",
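
(Aside, not part of the diff: per the prefix table above, SD3 keeps CLIP-G in the transformers layout, so only the open_clip-style targets still need the resblocks conversion. A paraphrase of the new condition:)

    def needs_openclip_conversion(CLIP_key, model_type):
        # keys already in open_clip 'resblocks' form never need converting,
        # and SD3 targets the transformers layout, so it is copied through as-is
        return "resblocks" not in CLIP_key and model_type != "sd3"
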
@@ -320,15 +324,15 @@ def replace_state_dict(sd, asd, guess):
                             "self_attn.out_proj" : "attn.out_proj" ,
                         }

-                        for x in keys_to_replace:
+                        for x in keys_to_replace: # remove trailing 'transformer.' from new prefix
                             k = x.format(prefix_from)
-                            statedict[keys_to_replace[x].format(prefix_to)] = statedict.pop(k)
+                            statedict[keys_to_replace[x].format(prefix_to[:-12])] = statedict.pop(k)

                         for resblock in range(number):
                             for y in ["weight", "bias"]:
                                 for x in resblock_to_replace:
                                     k = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, x, y)
-                                    k_to = "{}transformer.resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
+                                    k_to = "{}resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                                     statedict[k_to] = statedict.pop(k)

                                 k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.q_proj", y)
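
(Aside, not part of the diff: the [:-12] slice works because every converted target prefix now ends in 'transformer.', which is exactly 12 characters, so the non-resblock keys land one level up, at the encoder root. A quick check with the SDXL prefix from the table above:)

    prefix_to = "conditioner.embedders.1.model.transformer."
    assert len("transformer.") == 12
    assert prefix_to[:-12] == "conditioner.embedders.1.model."
    # "{}positional_embedding".format(prefix_to[:-12]) == "conditioner.embedders.1.model.positional_embedding"
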
@@ -338,14 +342,16 @@ def replace_state_dict(sd, asd, guess):
                                 k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.v_proj", y)
                                 weightsV = statedict.pop(k_from)

-                                k_to = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
+                                k_to = "{}resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
+
                                 statedict[k_to] = torch.cat((weightsQ, weightsK, weightsV))
                         return statedict

                     asd = convert_transformers(asd, old_prefix, new_prefix, 32)
-                    new_prefix = ""
+                    for k, v in asd.items():
+                        sd[k] = v

-                if old_prefix == "":
+                elif old_prefix == "":
                     for k, v in asd.items():
                         new_k = new_prefix + k
                         sd[new_k] = v
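
(Aside, not part of the diff: transformers-style CLIP stores separate q/k/v projection matrices, while the open_clip resblocks layout keeps one fused in_proj, hence the torch.cat. A standalone sketch with stand-in tensors; 1280 is the CLIP-G width checked earlier in this function:)

    import torch

    dim = 1280                               # CLIP-G width, per the shape[0] == 1280 check above
    q = torch.randn(dim, dim)                # stand-ins for the popped q/k/v projection weights
    k = torch.randn(dim, dim)
    v = torch.randn(dim, dim)
    in_proj_weight = torch.cat((q, k, v))    # fused open_clip layout
    assert in_proj_weight.shape == (3 * dim, dim)
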
@@ -360,7 +366,7 @@ def replace_state_dict(sd, asd, guess):
         'conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'conditioner.embedders.0.transformer.',
         'text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_l.transformer.',
         'text_model.encoder.layers.0.layer_norm1.bias' : '',
-        'transformer.resblocks.0.ln_1.bias' : ''
+        'transformer.resblocks.0.ln_1.bias' : 'transformer.'
     }

     for CLIP_key in CLIP_L.keys():
@@ -376,6 +382,7 @@ def replace_state_dict(sd, asd, guess):
                             "token_embedding.weight": "{}text_model.embeddings.token_embedding.weight",
                             "ln_final.weight" : "{}text_model.final_layer_norm.weight",
                             "ln_final.bias" : "{}text_model.final_layer_norm.bias",
+                            "text_projection" : "text_projection.weight",
                         }
                         resblock_to_replace = {
                             "ln_1" : "layer_norm1",
@@ -391,11 +398,11 @@ def replace_state_dict(sd, asd, guess):
                         for resblock in range(number):
                             for y in ["weight", "bias"]:
                                 for x in resblock_to_replace:
-                                    k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
+                                    k = "{}resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
                                     k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
                                     statedict[k_to] = statedict.pop(k)

-                                k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
+                                k_from = "{}resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
                                 weights = statedict.pop(k_from)
                                 shape_from = weights.shape[0] // 3
                                 for x in range(3):
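
(Aside, not part of the diff: this function goes the other way, cutting the fused open_clip in_proj weight into three equal slices to recover transformers-style q/k/v. A sketch with an assumed CLIP-L width of 768, which is not shown in this hunk:)

    import torch

    dim = 768                                    # assumed CLIP-L width
    in_proj_weight = torch.randn(3 * dim, dim)   # fused open_clip weight
    shape_from = in_proj_weight.shape[0] // 3    # same arithmetic as above
    q_w, k_w, v_w = (in_proj_weight[x * shape_from:(x + 1) * shape_from] for x in range(3))
    assert q_w.shape == k_w.shape == v_w.shape == (dim, dim)
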
@@ -405,9 +412,10 @@ def replace_state_dict(sd, asd, guess):
                         return statedict

                     asd = transformers_convert(asd, old_prefix, new_prefix, 12)
-                    new_prefix = ""
+                    for k, v in asd.items():
+                        sd[k] = v

-                if old_prefix == "":
+                elif old_prefix == "":
                     for k, v in asd.items():
                         new_k = new_prefix + k
                         sd[new_k] = v