Add SDXL refiner model (#2686)

add sdxlrefiner
adjust some Settings
custom CLIP-G support
This commit is contained in:
DenOfEquity
2025-02-25 10:49:47 +00:00
committed by GitHub
parent c4b6fccefc
commit 8dd92501e6
9 changed files with 134 additions and 17 deletions

View File

@@ -19,11 +19,11 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
from backend.diffusion_engine.sd15 import StableDiffusion
from backend.diffusion_engine.sd20 import StableDiffusion2
from backend.diffusion_engine.sdxl import StableDiffusionXL
from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
from backend.diffusion_engine.flux import Flux
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, Flux]
logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -222,6 +222,8 @@ def replace_state_dict(sd, asd, guess):
model_type = "sd1"
case 1024:
model_type = "sd2"
case 1280:
model_type = "xlrf" # sdxl refiner model
case 2048:
model_type = "sdxl"
elif flux_test_key in sd:
@@ -234,6 +236,7 @@ def replace_state_dict(sd, asd, guess):
"-" : None,
"sd1" : "cond_stage_model.transformer.",
"sd2" : None,
"xlrf": None,
"sdxl": "conditioner.embedders.0.transformer.",
"flux": "text_encoders.clip_l.transformer.",
"sd3" : "text_encoders.clip_l.transformer.",
@@ -243,6 +246,7 @@ def replace_state_dict(sd, asd, guess):
"-" : None,
"sd1" : None,
"sd2" : None,
"xlrf": "conditioner.embedders.0.model.",
"sdxl": "conditioner.embedders.1.model.",
"flux": None,
"sd3" : "text_encoders.clip_g.",
@@ -252,6 +256,7 @@ def replace_state_dict(sd, asd, guess):
"-" : None,
"sd1" : None,
"sd2" : "conditioner.embedders.0.model.",
"xlrf": None,
"sdxl": None,
"flux": None,
"sd3" : None,
@@ -261,7 +266,7 @@ def replace_state_dict(sd, asd, guess):
## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
if "first_stage_model.decoder.conv_in.weight" in asd:
channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1]
if model_type == "sd1" or model_type == "sd2" or model_type == "sdxl":
if model_type == "sd1" or model_type == "sd2" or model_type == "xlrf" or model_type == "sdxl":
if channels == 4:
for k, v in asd.items():
sd[k] = v