Add Stable Diffusion 3.5 (SD3.5) integration

This commit is contained in:
grae
2024-10-25 18:39:45 -06:00
parent d4d8ad406e
commit 1363999fb1
32 changed files with 332140 additions and 11 deletions

View File

@@ -19,11 +19,12 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
from backend.diffusion_engine.sd15 import StableDiffusion
from backend.diffusion_engine.sd20 import StableDiffusion2
from backend.diffusion_engine.sd35 import StableDiffusion3
from backend.diffusion_engine.sdxl import StableDiffusionXL
from backend.diffusion_engine.flux import Flux
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, StableDiffusion3, Flux]
logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -107,7 +108,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
return model
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
model_loader = None
@@ -116,6 +117,9 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
if cls_name == 'FluxTransformer2DModel':
from backend.nn.flux import IntegratedFluxTransformer2DModel
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
if cls_name == 'SD3Transformer2DModel':
from modules.models.sd35.mmditx import MMDiTX
model_loader = lambda c: MMDiTX(**c)
unet_config = guess.unet_config.copy()
state_dict_parameters = memory_management.state_dict_parameters(state_dict)
@@ -170,7 +174,7 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
return None
def replace_state_dict(sd, asd, guess):
def replace_state_dict(sd, asd, guess, is_clip_g = False):
vae_key_prefix = guess.vae_key_prefix[0]
text_encoder_key_prefix = guess.text_encoder_key_prefix[0]
@@ -210,11 +214,18 @@ def replace_state_dict(sd, asd, guess):
sd[vae_key_prefix + k] = v
if 'text_model.encoder.layers.0.layer_norm1.weight' in asd:
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
for k in keys_to_delete:
del sd[k]
for k, v in asd.items():
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
if is_clip_g:
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_g.")]
for k in keys_to_delete:
del sd[k]
for k, v in asd.items():
sd[f"{text_encoder_key_prefix}clip_g.transformer.{k}"] = v
else:
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}clip_l.")]
for k in keys_to_delete:
del sd[k]
for k, v in asd.items():
sd[f"{text_encoder_key_prefix}clip_l.transformer.{k}"] = v
if 'encoder.block.0.layer.0.SelfAttention.k.weight' in asd:
keys_to_delete = [k for k in sd if k.startswith(f"{text_encoder_key_prefix}t5xxl.")]
@@ -241,8 +252,9 @@ def split_state_dict(sd, additional_state_dicts: list = None):
if isinstance(additional_state_dicts, list):
for asd in additional_state_dicts:
is_clip_g = 'clip_g' in asd
asd = load_torch_file(asd)
sd = replace_state_dict(sd, asd, guess)
sd = replace_state_dict(sd, asd, guess, is_clip_g)
guess.clip_target = guess.clip_target(sd)
guess.model_type = guess.model_type(sd)