mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 19:21:21 +00:00
sd3.5 integration
This commit is contained in:
@@ -19,11 +19,12 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
|
||||
|
||||
from backend.diffusion_engine.sd15 import StableDiffusion
|
||||
from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||
from backend.diffusion_engine.sd35 import StableDiffusion3
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
||||
from backend.diffusion_engine.flux import Flux
|
||||
|
||||
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, StableDiffusion3, Flux]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
@@ -107,7 +108,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
||||
|
||||
return model
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
||||
|
||||
model_loader = None
|
||||
@@ -116,6 +117,9 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
if cls_name == 'FluxTransformer2DModel':
|
||||
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
||||
if cls_name == 'SD3Transformer2DModel':
|
||||
from modules.models.sd35.mmditx import MMDiTX
|
||||
model_loader = lambda c: MMDiTX(**c)
|
||||
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_parameters = memory_management.state_dict_parameters(state_dict)
|
||||
@@ -170,7 +174,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
return None
|
||||
|
||||
|
||||
def replace_state_dict(sd, asd, guess):
|
||||
def replace_state_dict(sd, asd, guess, is_clip_g = False):
|
||||
vae_key_prefix = guess.vae_key_prefix[0]
|
||||
text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
|
||||
|
||||
@@ -210,11 +214,18 @@ def replace_state_dict(sd, asd, guess):
|
||||
sd[vae_key_prefix + k] = v
|
||||
|
||||
if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
|
||||
for k in keys_to_delete:
|
||||
del sd[k]
|
||||
for k, v in asd.items():
|
||||
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
|
||||
if is_clip_g:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_g.")]
|
||||
for k in keys_to_delete:
|
||||
del sd[k]
|
||||
for k, v in asd.items():
|
||||
sd[f"{text_encoder_key_prefix}clip_g.transformer.{k}"] = v
|
||||
else:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
|
||||
for k in keys_to_delete:
|
||||
del sd[k]
|
||||
for k, v in asd.items():
|
||||
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
|
||||
|
||||
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
|
||||
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
|
||||
@@ -241,8 +252,9 @@ def split_state_dict(sd, additional_state_dicts: list = None):
|
||||
|
||||
if isinstance(additional_state_dicts, list):
|
||||
for asd in additional_state_dicts:
|
||||
is_clip_g = 'clip_g' in asd
|
||||
asd = load_torch_file(asd)
|
||||
sd = replace_state_dict(sd, asd, guess)
|
||||
sd = replace_state_dict(sd, asd, guess, is_clip_g)
|
||||
|
||||
guess.clip_target = guess.clip_target(sd)
|
||||
guess.model_type = guess.model_type(sd)
|
||||
|
||||
Reference in New Issue
Block a user