Add SDXL refiner model (#2686)

add sdxlrefiner
adjust some Settings
custom CLIP-G support
This commit is contained in:
DenOfEquity
2025-02-25 10:49:47 +00:00
committed by GitHub
parent c4b6fccefc
commit 8dd92501e6
9 changed files with 134 additions and 17 deletions

View File

@@ -13,6 +13,8 @@ from backend.nn.unet import Timestep
import safetensors.torch as sf
from backend import utils
from modules.shared import opts
class StableDiffusionXL(ForgeDiffusionEngine):
matched_guesses = [model_list.SDXL]
@@ -91,8 +93,8 @@ class StableDiffusionXL(ForgeDiffusionEngine):
height = getattr(prompt, 'height', 1024) or 1024
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
crop_w = 0
crop_h = 0
crop_w = opts.sdxl_crop_left
crop_h = opts.sdxl_crop_top
target_width = width
target_height = height
@@ -150,3 +152,121 @@ class StableDiffusionXL(ForgeDiffusionEngine):
)
sf.save_file(sd, filename)
return filename
class StableDiffusionXLRefiner(ForgeDiffusionEngine):
matched_guesses = [model_list.SDXLRefiner]
def __init__(self, estimated_config, huggingface_components):
super().__init__(estimated_config, huggingface_components)
clip = CLIP(
model_dict={
'clip_g': huggingface_components['text_encoder']
},
tokenizer_dict={
'clip_g': huggingface_components['tokenizer'],
}
)
vae = VAE(model=huggingface_components['vae'])
unet = UnetPatcher.from_model(
model=huggingface_components['unet'],
diffusers_scheduler=huggingface_components['scheduler'],
config=estimated_config
)
self.text_processing_engine_g = ClassicTextProcessingEngine(
text_encoder=clip.cond_stage_model.clip_g,
tokenizer=clip.tokenizer.clip_g,
embedding_dir=dynamic_args['embedding_dir'],
embedding_key='clip_g',
embedding_expected_shape=2048,
emphasis_name=dynamic_args['emphasis_name'],
text_projection=True,
minimal_clip_skip=2,
clip_skip=2,
return_pooled=True,
final_layer_norm=False,
)
self.embedder = Timestep(256)
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
self.forge_objects_original = self.forge_objects.shallow_copy()
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
# WebUI Legacy
self.is_sdxl = True
def set_clip_skip(self, clip_skip):
self.text_processing_engine_g.clip_skip = clip_skip
@torch.inference_mode()
def get_learned_conditioning(self, prompt: list[str]):
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
cond_g, clip_pooled = self.text_processing_engine_g(prompt)
width = getattr(prompt, 'width', 1024) or 1024
height = getattr(prompt, 'height', 1024) or 1024
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
crop_w = opts.sdxl_crop_left
crop_h = opts.sdxl_crop_top
aesthetic = opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else opts.sdxl_refiner_high_aesthetic_score
out = [
self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])),
self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])),
self.embedder(torch.Tensor([aesthetic]))
]
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1).to(clip_pooled)
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
if force_zero_negative_prompt:
clip_pooled = torch.zeros_like(clip_pooled)
cond_g = torch.zeros_like(cond_g)
cond = dict(
crossattn=cond_g,
vector=torch.cat([clip_pooled, flat], dim=1),
)
return cond
@torch.inference_mode()
def get_prompt_lengths_on_ui(self, prompt):
_, token_count = self.text_processing_engine_g.process_texts([prompt])
return token_count, self.text_processing_engine_g.get_target_prompt_token_count(token_count)
@torch.inference_mode()
def encode_first_stage(self, x):
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
return sample.to(x)
@torch.inference_mode()
def decode_first_stage(self, x):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
model_list.SDXLRefiner.process_clip_state_dict_for_saving(self,
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
)
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
)
sf.save_file(sd, filename)
return filename