Added PixArt Sigma support, but it won't work until I address the breaking changes to the LoRA code in diffusers so that diffusers can be upgraded.

This commit is contained in:
Jaret Burkett
2024-04-20 10:46:56 -06:00
parent 377b81ee3e
commit 5a70b7f38d
5 changed files with 603 additions and 18 deletions

View File

@@ -46,6 +46,7 @@ from diffusers import \
UNet2DConditionModel
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
from transformers import T5EncoderModel
from toolkit.util.pixart_sigma_patch import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline
from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from toolkit.util.inverse_cfg import inverse_classifier_guidance
@@ -237,15 +238,22 @@ class StableDiffusion:
elif self.model_config.is_pixart:
te_kwargs = {}
# handle quantization of TE
te_is_quantized = False
if self.model_config.text_encoder_bits == 8:
te_kwargs['load_in_8bit'] = True
te_kwargs['device_map'] = "auto"
te_is_quantized = True
elif self.model_config.text_encoder_bits == 4:
te_kwargs['load_in_4bit'] = True
te_kwargs['device_map'] = "auto"
te_is_quantized = True
main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
if self.model_config.is_pixart_sigma:
main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
# load the TE in 8bit mode
text_encoder = T5EncoderModel.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
main_model_path,
subfolder="text_encoder",
torch_dtype=self.torch_dtype,
**te_kwargs
@@ -256,20 +264,46 @@ class StableDiffusion:
# check if it is just the unet
if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
subfolder = None
# load the transformer only from the save
transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, subfolder=subfolder)
if te_is_quantized:
# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
transformer=transformer,
text_encoder=text_encoder,
dtype=dtype,
device=self.device_torch,
**load_args
).to(self.device_torch)
if self.model_config.is_pixart_sigma:
# tmp patches for diffusers PixArtSigmaPipeline Implementation
print(
"Changing _init_patched_inputs method of diffusers.models.Transformer2DModel "
"using scripts.diffusers_patches.pixart_sigma_init_patched_inputs")
setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
# load the transformer only from the save
transformer = Transformer2DModel.from_pretrained(
model_path,
torch_dtype=self.torch_dtype,
subfolder='transformer'
)
pipe: PixArtAlphaPipeline = PixArtSigmaPipeline.from_pretrained(
main_model_path,
transformer=transformer,
text_encoder=text_encoder,
dtype=dtype,
device=self.device_torch,
**load_args
)
else:
# load the transformer only from the save
transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype,
subfolder=subfolder)
pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
main_model_path,
transformer=transformer,
text_encoder=text_encoder,
dtype=dtype,
device=self.device_torch,
**load_args
).to(self.device_torch)
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
flush()
@@ -1282,7 +1316,7 @@ class StableDiffusion:
self.text_encoder,
prompt,
truncate=not long_prompts,
max_length=max_length,
max_length=300 if self.model_config.is_pixart_sigma else 120,
dropout_prob=dropout_prob
)
return PromptEmbeds(