mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 19:21:21 +00:00
Add SDXL refiner model (#2686)
Add the SDXL refiner engine, adjust some settings, and add custom CLIP-G support.
This commit is contained in:
@@ -13,6 +13,8 @@ from backend.nn.unet import Timestep
|
|||||||
import safetensors.torch as sf
|
import safetensors.torch as sf
|
||||||
from backend import utils
|
from backend import utils
|
||||||
|
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXL(ForgeDiffusionEngine):
|
class StableDiffusionXL(ForgeDiffusionEngine):
|
||||||
matched_guesses = [model_list.SDXL]
|
matched_guesses = [model_list.SDXL]
|
||||||
@@ -91,8 +93,8 @@ class StableDiffusionXL(ForgeDiffusionEngine):
|
|||||||
height = getattr(prompt, 'height', 1024) or 1024
|
height = getattr(prompt, 'height', 1024) or 1024
|
||||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||||
|
|
||||||
crop_w = 0
|
crop_w = opts.sdxl_crop_left
|
||||||
crop_h = 0
|
crop_h = opts.sdxl_crop_top
|
||||||
target_width = width
|
target_width = width
|
||||||
target_height = height
|
target_height = height
|
||||||
|
|
||||||
@@ -150,3 +152,121 @@ class StableDiffusionXL(ForgeDiffusionEngine):
|
|||||||
)
|
)
|
||||||
sf.save_file(sd, filename)
|
sf.save_file(sd, filename)
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusionXLRefiner(ForgeDiffusionEngine):
    """Diffusion engine for checkpoints matched as ``model_list.SDXLRefiner``.

    Only a single text encoder, registered under the ``clip_g`` key, is wired
    up. Conditioning consists of that encoder's output plus a vector built
    from its pooled output and size / crop / aesthetic-score embeddings.
    """

    matched_guesses = [model_list.SDXLRefiner]

    def __init__(self, estimated_config, huggingface_components):
        """Assemble the CLIP, VAE and UNet wrappers from the loaded components.

        Args:
            estimated_config: Model configuration; forwarded to the base
                class and to ``UnetPatcher.from_model``.
            huggingface_components: Mapping that must provide the keys
                ``'text_encoder'``, ``'tokenizer'``, ``'vae'``, ``'unet'``
                and ``'scheduler'``.
        """
        super().__init__(estimated_config, huggingface_components)

        components = huggingface_components

        # The single text encoder / tokenizer pair is registered as CLIP-G.
        clip = CLIP(
            model_dict={'clip_g': components['text_encoder']},
            tokenizer_dict={'clip_g': components['tokenizer']},
        )

        vae = VAE(model=components['vae'])

        unet = UnetPatcher.from_model(
            model=components['unet'],
            diffusers_scheduler=components['scheduler'],
            config=estimated_config,
        )

        engine_options = {
            'text_encoder': clip.cond_stage_model.clip_g,
            'tokenizer': clip.tokenizer.clip_g,
            'embedding_dir': dynamic_args['embedding_dir'],
            'embedding_key': 'clip_g',
            'embedding_expected_shape': 2048,
            'emphasis_name': dynamic_args['emphasis_name'],
            'text_projection': True,
            'minimal_clip_skip': 2,
            'clip_skip': 2,
            'return_pooled': True,
            'final_layer_norm': False,
        }
        self.text_processing_engine_g = ClassicTextProcessingEngine(**engine_options)

        # Used in get_learned_conditioning to embed the scalar conditioning values.
        self.embedder = Timestep(256)

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

        # WebUI Legacy
        self.is_sdxl = True

    def set_clip_skip(self, clip_skip):
        """Store ``clip_skip`` on the CLIP-G text processing engine."""
        self.text_processing_engine_g.clip_skip = clip_skip

    @torch.inference_mode()
    def get_learned_conditioning(self, prompt: list[str]):
        """Encode ``prompt`` into the conditioning dict.

        ``prompt`` is iterated as a sequence of strings and is also probed
        for optional ``width``, ``height`` and ``is_negative_prompt``
        attributes (missing or falsy sizes fall back to 1024).

        Returns:
            dict with ``crossattn`` (text encoder output) and ``vector``
            (pooled output concatenated with the flattened embeddings of
            height, width, crop top, crop left and aesthetic score).
        """
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)

        cond_g, clip_pooled = self.text_processing_engine_g(prompt)

        width = getattr(prompt, 'width', 1024) or 1024
        height = getattr(prompt, 'height', 1024) or 1024
        is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)

        crop_w = opts.sdxl_crop_left
        crop_h = opts.sdxl_crop_top

        # Negative prompts use the "low" aesthetic score, positive ones the "high" score.
        if is_negative_prompt:
            aesthetic = opts.sdxl_refiner_low_aesthetic_score
        else:
            aesthetic = opts.sdxl_refiner_high_aesthetic_score

        # Order matters: it fixes the layout of the conditioning vector.
        scalars = (height, width, crop_h, crop_w, aesthetic)
        embedded = [self.embedder(torch.Tensor([value])) for value in scalars]

        # One flat row, repeated per prompt, matched to the dtype/device of the pooled output.
        batch_size = clip_pooled.shape[0]
        flat = torch.cat(embedded).flatten().unsqueeze(dim=0)
        flat = flat.repeat(batch_size, 1).to(clip_pooled)

        # A negative prompt made up solely of empty strings is replaced by zeros.
        if is_negative_prompt and all(x == '' for x in prompt):
            clip_pooled = torch.zeros_like(clip_pooled)
            cond_g = torch.zeros_like(cond_g)

        return {
            'crossattn': cond_g,
            'vector': torch.cat([clip_pooled, flat], dim=1),
        }

    @torch.inference_mode()
    def get_prompt_lengths_on_ui(self, prompt):
        """Return ``(token_count, target_token_count)`` for a single prompt."""
        engine = self.text_processing_engine_g
        _, token_count = engine.process_texts([prompt])
        target_count = engine.get_target_prompt_token_count(token_count)
        return token_count, target_count

    @torch.inference_mode()
    def encode_first_stage(self, x):
        """Encode ``x`` with the VAE and apply the model's ``process_in``.

        ``x`` is rescaled with ``* 0.5 + 0.5`` and has dim 1 moved to the
        last position before encoding; the result is cast back to match ``x``.
        """
        vae = self.forge_objects.vae
        pixels = x.movedim(1, -1) * 0.5 + 0.5
        latent = vae.encode(pixels)
        latent = vae.first_stage_model.process_in(latent)
        return latent.to(x)

    @torch.inference_mode()
    def decode_first_stage(self, x):
        """Apply the model's ``process_out`` to ``x`` and decode with the VAE.

        The decoded result has its last dim moved to position 1 and is
        rescaled with ``* 2.0 - 1.0`` before being cast to match ``x``.
        """
        vae = self.forge_objects.vae
        latent = vae.first_stage_model.process_out(x)
        decoded = vae.decode(latent).movedim(-1, 1) * 2.0 - 1.0
        return decoded.to(x)

    def save_checkpoint(self, filename):
        """Write the UNet, text encoder and VAE weights to ``filename``.

        The three state dicts are merged in that order (later entries win on
        key collisions) and saved with safetensors. Returns ``filename``.
        """
        objects = self.forge_objects
        sd = {}

        unet_sd = utils.get_state_dict_after_quant(
            objects.unet.model.diffusion_model, prefix='model.diffusion_model.'
        )
        sd.update(unet_sd)

        clip_sd = utils.get_state_dict_after_quant(objects.clip.cond_stage_model, prefix='')
        # The model_list helper is called unbound, with this engine standing in for its ``self``.
        sd.update(model_list.SDXLRefiner.process_clip_state_dict_for_saving(self, clip_sd))

        vae_sd = utils.get_state_dict_after_quant(
            objects.vae.first_stage_model, prefix='first_stage_model.'
        )
        sd.update(vae_sd)

        sf.save_file(sd, filename)
        return filename
|
||||||
|
|||||||
@@ -9,18 +9,10 @@
|
|||||||
"EulerDiscreteScheduler"
|
"EulerDiscreteScheduler"
|
||||||
],
|
],
|
||||||
"text_encoder": [
|
"text_encoder": [
|
||||||
null,
|
|
||||||
null
|
|
||||||
],
|
|
||||||
"text_encoder_2": [
|
|
||||||
"transformers",
|
"transformers",
|
||||||
"CLIPTextModelWithProjection"
|
"CLIPTextModelWithProjection"
|
||||||
],
|
],
|
||||||
"tokenizer": [
|
"tokenizer": [
|
||||||
null,
|
|
||||||
null
|
|
||||||
],
|
|
||||||
"tokenizer_2": [
|
|
||||||
"transformers",
|
"transformers",
|
||||||
"CLIPTokenizer"
|
"CLIPTokenizer"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
|
|||||||
|
|
||||||
from backend.diffusion_engine.sd15 import StableDiffusion
|
from backend.diffusion_engine.sd15 import StableDiffusion
|
||||||
from backend.diffusion_engine.sd20 import StableDiffusion2
|
from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||||
from backend.diffusion_engine.sdxl import StableDiffusionXL
|
from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
|
||||||
from backend.diffusion_engine.flux import Flux
|
from backend.diffusion_engine.flux import Flux
|
||||||
|
|
||||||
|
|
||||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
|
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, Flux]
|
||||||
|
|
||||||
|
|
||||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||||
@@ -222,6 +222,8 @@ def replace_state_dict(sd, asd, guess):
|
|||||||
model_type = "sd1"
|
model_type = "sd1"
|
||||||
case 1024:
|
case 1024:
|
||||||
model_type = "sd2"
|
model_type = "sd2"
|
||||||
|
case 1280:
|
||||||
|
model_type = "xlrf" # sdxl refiner model
|
||||||
case 2048:
|
case 2048:
|
||||||
model_type = "sdxl"
|
model_type = "sdxl"
|
||||||
elif flux_test_key in sd:
|
elif flux_test_key in sd:
|
||||||
@@ -234,6 +236,7 @@ def replace_state_dict(sd, asd, guess):
|
|||||||
"-" : None,
|
"-" : None,
|
||||||
"sd1" : "cond_stage_model.transformer.",
|
"sd1" : "cond_stage_model.transformer.",
|
||||||
"sd2" : None,
|
"sd2" : None,
|
||||||
|
"xlrf": None,
|
||||||
"sdxl": "conditioner.embedders.0.transformer.",
|
"sdxl": "conditioner.embedders.0.transformer.",
|
||||||
"flux": "text_encoders.clip_l.transformer.",
|
"flux": "text_encoders.clip_l.transformer.",
|
||||||
"sd3" : "text_encoders.clip_l.transformer.",
|
"sd3" : "text_encoders.clip_l.transformer.",
|
||||||
@@ -243,6 +246,7 @@ def replace_state_dict(sd, asd, guess):
|
|||||||
"-" : None,
|
"-" : None,
|
||||||
"sd1" : None,
|
"sd1" : None,
|
||||||
"sd2" : None,
|
"sd2" : None,
|
||||||
|
"xlrf": "conditioner.embedders.0.model.",
|
||||||
"sdxl": "conditioner.embedders.1.model.",
|
"sdxl": "conditioner.embedders.1.model.",
|
||||||
"flux": None,
|
"flux": None,
|
||||||
"sd3" : "text_encoders.clip_g.",
|
"sd3" : "text_encoders.clip_g.",
|
||||||
@@ -252,6 +256,7 @@ def replace_state_dict(sd, asd, guess):
|
|||||||
"-" : None,
|
"-" : None,
|
||||||
"sd1" : None,
|
"sd1" : None,
|
||||||
"sd2" : "conditioner.embedders.0.model.",
|
"sd2" : "conditioner.embedders.0.model.",
|
||||||
|
"xlrf": None,
|
||||||
"sdxl": None,
|
"sdxl": None,
|
||||||
"flux": None,
|
"flux": None,
|
||||||
"sd3" : None,
|
"sd3" : None,
|
||||||
@@ -261,7 +266,7 @@ def replace_state_dict(sd, asd, guess):
|
|||||||
## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
|
## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
|
||||||
if "first_stage_model.decoder.conv_in.weight" in asd:
|
if "first_stage_model.decoder.conv_in.weight" in asd:
|
||||||
channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1]
|
channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1]
|
||||||
if model_type == "sd1" or model_type == "sd2" or model_type == "sdxl":
|
if model_type == "sd1" or model_type == "sd2" or model_type == "xlrf" or model_type == "sdxl":
|
||||||
if channels == 4:
|
if channels == 4:
|
||||||
for k, v in asd.items():
|
for k, v in asd.items():
|
||||||
sd[k] = v
|
sd[k] = v
|
||||||
|
|||||||
@@ -188,10 +188,10 @@ options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
|
options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
|
||||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
"sdxl_crop_top": OptionInfo(0, "crop top coordinate", gr.Number, {"minimum": 0, "maximum": 1024, "step": 1}),
|
||||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
"sdxl_crop_left": OptionInfo(0, "crop left coordinate", gr.Number, {"minimum": 0, "maximum": 1024, "step": 1}),
|
||||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Slider, {"minimum": 0, "maximum": 10, "step": 0.1}).info("used for refiner model negative prompt"),
|
||||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Slider, {"minimum": 0, "maximum": 10, "step": 0.1}).info("used for refiner model prompt"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
|
options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
|
||||||
|
|||||||
Reference in New Issue
Block a user