From e3ad534de8b38e56ec79e79b0ef1d2aa9e46be16 Mon Sep 17 00:00:00 2001 From: Grae Date: Fri, 25 Oct 2024 18:18:43 -0600 Subject: [PATCH] Update model_list.py adding SD3.5 (does hijack SD3) --- huggingface_guess/model_list.py | 36 ++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/huggingface_guess/model_list.py b/huggingface_guess/model_list.py index cca4b1c..52cb322 100644 --- a/huggingface_guess/model_list.py +++ b/huggingface_guess/model_list.py @@ -587,7 +587,41 @@ class SD3(BASE): result['t5xxl'] = 'text_encoder_3' return result + +class SD35(BASE): + huggingface_repo = "stabilityai/stable-diffusion-3.5-large" + + unet_config = { + "in_channels": 16, + "pos_embed_scaling_factor": None, + } + sampling_settings = { + "shift": 3.0, + } + + unet_extra_config = {} + latent_format = latent.SD3 + + memory_usage_factor = 1.2 + + text_encoder_key_prefix = ["text_encoders."] + unet_target = 'transformer' + + def clip_target(self, state_dict={}): + result = {} + pref = self.text_encoder_key_prefix[0] + + if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: + result['clip_l'] = 'text_encoder' + + if "{}clip_g.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict: + result['clip_g'] = 'text_encoder_2' + + if "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref) in state_dict: + result['t5xxl'] = 'text_encoder_3' + + return result class StableAudio(BASE): unet_config = { @@ -713,6 +747,6 @@ class FluxSchnell(Flux): } -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell] +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD35, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell] models += [SVD_img2vid]