diff --git a/backend/diffusion_engine/sdxl.py b/backend/diffusion_engine/sdxl.py
index d12fe195..1c3b2424 100644
--- a/backend/diffusion_engine/sdxl.py
+++ b/backend/diffusion_engine/sdxl.py
@@ -13,6 +13,8 @@ from backend.nn.unet import Timestep
 import safetensors.torch as sf
 from backend import utils
 
+from modules.shared import opts
+
 
 class StableDiffusionXL(ForgeDiffusionEngine):
     matched_guesses = [model_list.SDXL]
@@ -91,8 +93,8 @@ class StableDiffusionXL(ForgeDiffusionEngine):
         height = getattr(prompt, 'height', 1024) or 1024
         is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
 
-        crop_w = 0
-        crop_h = 0
+        crop_w = opts.sdxl_crop_left
+        crop_h = opts.sdxl_crop_top
         target_width = width
         target_height = height
 
@@ -150,3 +152,121 @@ class StableDiffusionXL(ForgeDiffusionEngine):
         )
         sf.save_file(sd, filename)
         return filename
+
+
+class StableDiffusionXLRefiner(ForgeDiffusionEngine):
+    matched_guesses = [model_list.SDXLRefiner]
+
+    def __init__(self, estimated_config, huggingface_components):
+        super().__init__(estimated_config, huggingface_components)
+
+        clip = CLIP(
+            model_dict={
+                'clip_g': huggingface_components['text_encoder']
+            },
+            tokenizer_dict={
+                'clip_g': huggingface_components['tokenizer'],
+            }
+        )
+
+        vae = VAE(model=huggingface_components['vae'])
+
+        unet = UnetPatcher.from_model(
+            model=huggingface_components['unet'],
+            diffusers_scheduler=huggingface_components['scheduler'],
+            config=estimated_config
+        )
+
+        self.text_processing_engine_g = ClassicTextProcessingEngine(
+            text_encoder=clip.cond_stage_model.clip_g,
+            tokenizer=clip.tokenizer.clip_g,
+            embedding_dir=dynamic_args['embedding_dir'],
+            embedding_key='clip_g',
+            embedding_expected_shape=2048,
+            emphasis_name=dynamic_args['emphasis_name'],
+            text_projection=True,
+            minimal_clip_skip=2,
+            clip_skip=2,
+            return_pooled=True,
+            final_layer_norm=False,
+        )
+
+        self.embedder = Timestep(256)
+
+        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
+        self.forge_objects_original = self.forge_objects.shallow_copy()
+        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
+
+        # WebUI Legacy
+        self.is_sdxl = True
+
+    def set_clip_skip(self, clip_skip):
+        self.text_processing_engine_g.clip_skip = clip_skip
+
+    @torch.inference_mode()
+    def get_learned_conditioning(self, prompt: list[str]):
+        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
+
+        cond_g, clip_pooled = self.text_processing_engine_g(prompt)
+
+        width = getattr(prompt, 'width', 1024) or 1024
+        height = getattr(prompt, 'height', 1024) or 1024
+        is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
+
+        crop_w = opts.sdxl_crop_left
+        crop_h = opts.sdxl_crop_top
+        aesthetic = opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else opts.sdxl_refiner_high_aesthetic_score
+
+        out = [
+            self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
+            self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
+            self.embedder(torch.Tensor([aesthetic]))
+        ]
+
+        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
+
+        force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
+
+        if force_zero_negative_prompt:
+            clip_pooled = torch.zeros_like(clip_pooled)
+            cond_g = torch.zeros_like(cond_g)
+
+        cond = dict(
+            crossattn=cond_g,
+            vector=torch.cat([clip_pooled, flat], dim=1),
+        )
+
+        return cond
+
+    @torch.inference_mode()
+    def get_prompt_lengths_on_ui(self, prompt):
+        _, token_count = self.text_processing_engine_g.process_texts([prompt])
+        return token_count, self.text_processing_engine_g.get_target_prompt_token_count(token_count)
+
+    @torch.inference_mode()
+    def encode_first_stage(self, x):
+        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
+        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
+        return sample.to(x)
+
+    @torch.inference_mode()
+    def decode_first_stage(self, x):
+        sample = self.forge_objects.vae.first_stage_model.process_out(x)
+        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
+        return sample.to(x)
+
+    def save_checkpoint(self, filename):
+        sd = {}
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
+        )
+        sd.update(
+            model_list.SDXLRefiner.process_clip_state_dict_for_saving(self,
+                utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
+            )
+        )
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
+        )
+        sf.save_file(sd, filename)
+        return filename
diff --git a/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/model_index.json b/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/model_index.json
index d1024c84..472406db 100644
--- a/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/model_index.json
+++ b/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/model_index.json
@@ -9,18 +9,10 @@
     "EulerDiscreteScheduler"
   ],
   "text_encoder": [
-    null,
-    null
-  ],
-  "text_encoder_2": [
     "transformers",
     "CLIPTextModelWithProjection"
   ],
   "tokenizer": [
-    null,
-    null
-  ],
-  "tokenizer_2": [
     "transformers",
     "CLIPTokenizer"
   ],
diff --git a/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/text_encoder_2/config.json b/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/text_encoder/config.json
similarity index 100%
rename from backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/text_encoder_2/config.json
rename to backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/text_encoder/config.json
diff --git a/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer_2/merges.txt b/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer/merges.txt
similarity index 100%
rename from backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer_2/merges.txt
rename to backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer/merges.txt
diff --git a/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer_2/special_tokens_map.json b/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer/special_tokens_map.json
similarity index 100%
rename from backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer_2/special_tokens_map.json
rename to backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer/special_tokens_map.json
diff --git a/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer_2/tokenizer_config.json b/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer/tokenizer_config.json
similarity index 100%
rename from backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer_2/tokenizer_config.json
rename to backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer/tokenizer_config.json
diff --git a/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer_2/vocab.json b/backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer/vocab.json
similarity index 100%
rename from backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer_2/vocab.json
rename to backend/huggingface/stabilityai/stable-diffusion-xl-refiner-1.0/tokenizer/vocab.json
diff --git a/backend/loader.py b/backend/loader.py
index baf0d347..50803245 100644
--- a/backend/loader.py
+++ b/backend/loader.py
@@ -19,11 +19,11 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
 
 from backend.diffusion_engine.sd15 import StableDiffusion
 from backend.diffusion_engine.sd20 import StableDiffusion2
-from backend.diffusion_engine.sdxl import StableDiffusionXL
+from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
 from backend.diffusion_engine.flux import Flux
 
 
-possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXL, Flux]
+possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, Flux]
 
 
 logging.getLogger("diffusers").setLevel(logging.ERROR)
@@ -222,6 +222,8 @@ def replace_state_dict(sd, asd, guess):
                     model_type = "sd1"
                 case 1024:
                     model_type = "sd2"
+                case 1280:
+                    model_type = "xlrf"  # sdxl refiner model
                 case 2048:
                     model_type = "sdxl"
     elif flux_test_key in sd:
@@ -234,6 +236,7 @@ def replace_state_dict(sd, asd, guess):
         "-" : None,
         "sd1" : "cond_stage_model.transformer.",
         "sd2" : None,
+        "xlrf": None,
         "sdxl": "conditioner.embedders.0.transformer.",
         "flux": "text_encoders.clip_l.transformer.",
         "sd3" : "text_encoders.clip_l.transformer.",
@@ -243,6 +246,7 @@ def replace_state_dict(sd, asd, guess):
         "-" : None,
         "sd1" : None,
         "sd2" : None,
+        "xlrf": "conditioner.embedders.0.model.",
         "sdxl": "conditioner.embedders.1.model.",
         "flux": None,
         "sd3" : "text_encoders.clip_g.",
@@ -252,6 +256,7 @@ def replace_state_dict(sd, asd, guess):
         "-" : None,
         "sd1" : None,
         "sd2" : "conditioner.embedders.0.model.",
+        "xlrf": None,
         "sdxl": None,
         "flux": None,
         "sd3" : None,
@@ -261,7 +266,7 @@ def replace_state_dict(sd, asd, guess):
     ## VAE format 0 (extracted from model, could be sd1, sd2, sdxl, sd3).
if "first_stage_model.decoder.conv_in.weight" in asd: channels = asd["first_stage_model.decoder.conv_in.weight"].shape[1] - if model_type == "sd1" or model_type == "sd2" or model_type == "sdxl": + if model_type == "sd1" or model_type == "sd2" or model_type == "xlrf" or model_type == "sdxl": if channels == 4: for k, v in asd.items(): sd[k] = v diff --git a/modules/shared_options.py b/modules/shared_options.py index ca96341e..51741126 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -188,10 +188,10 @@ options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), { })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), { - "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), - "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), - "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), - "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), + "sdxl_crop_top": OptionInfo(0, "crop top coordinate", gr.Number, {"minimum": 0, "maximum": 1024, "step": 1}), + "sdxl_crop_left": OptionInfo(0, "crop left coordinate", gr.Number, {"minimum": 0, "maximum": 1024, "step": 1}), + "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Slider, {"minimum": 0, "maximum": 10, "step": 0.1}).info("used for refiner model negative prompt"), + "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Slider, {"minimum": 0, "maximum": 10, "step": 0.1}).info("used for refiner model prompt"), })) options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {