mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-28 18:21:48 +00:00
Add SDXL refiner model (#2686)
add sdxlrefiner adjust some Settings custom CLIP-G support
This commit is contained in:
@@ -19,11 +19,11 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
|
||||
|
||||
from backend.diffusion_engine.sd15 import StableDiffusion
|
||||
from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
|
||||
from backend.diffusion_engine.flux import Flux
|
||||
|
||||
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, Flux]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
@@ -222,6 +222,8 @@ def replace_state_dict(sd, asd, guess):
|
||||
model_type = "sd1"
|
||||
case 1024:
|
||||
model_type = "sd2"
|
||||
case 1280:
|
||||
model_type = "xlrf" # sdxl refiner model
|
||||
case 2048:
|
||||
model_type = "sdxl"
|
||||
elif flux_test_key in sd:
|
||||
@@ -234,6 +236,7 @@ def replace_state_dict(sd, asd, guess):
|
||||
"-" : None,
|
||||
"sd1" : "cond_stage_model.transformer.",
|
||||
"sd2" : None,
|
||||
"xlrf": None,
|
||||
"sdxl": "conditioner.embedders.0.transformer.",
|
||||
"flux": "text_encoders.clip_l.transformer.",
|
||||
"sd3" : "text_encoders.clip_l.transformer.",
|
||||
@@ -243,6 +246,7 @@ def replace_state_dict(sd, asd, guess):
|
||||
"-" : None,
|
||||
"sd1" : None,
|
||||
"sd2" : None,
|
||||
"xlrf": "conditioner.embedders.0.model.",
|
||||
"sdxl": "conditioner.embedders.1.model.",
|
||||
"flux": None,
|
||||
"sd3" : "text_encoders.clip_g.",
|
||||
@@ -252,6 +256,7 @@ def replace_state_dict(sd, asd, guess):
|
||||
"-" : None,
|
||||
"sd1" : None,
|
||||
"sd2" : "conditioner.embedders.0.model.",
|
||||
"xlrf": None,
|
||||
"sdxl": None,
|
||||
"flux": None,
|
||||
"sd3" : None,
|
||||
@@ -261,7 +266,7 @@ def replace_state_dict(sd, asd, guess):
|
||||
## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
|
||||
if "first_stage_model.decoder.conv_in.weight" in asd:
|
||||
channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1]
|
||||
if model_type == "sd1" or model_type == "sd2" or model_type == "sdxl":
|
||||
if model_type == "sd1" or model_type == "sd2" or model_type == "xlrf" or model_type == "sdxl":
|
||||
if channels == 4:
|
||||
for k, v in asd.items():
|
||||
sd[k] = v
|
||||
|
||||
Reference in New Issue
Block a user