Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-21 23:09:15 +00:00
Added PixArt Sigma support, but it won't work until I address breaking changes with the LoRA code in diffusers so it can be upgraded.
@@ -46,6 +46,7 @@ from diffusers import \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
 from transformers import T5EncoderModel
+from toolkit.util.pixart_sigma_patch import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline
 
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
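The added import pulls both the patched `_init_patched_inputs` function and a local `PixArtSigmaPipeline` from `toolkit.util.pixart_sigma_patch`, since the installed diffusers release does not yet ship a Sigma pipeline. A minimal sketch of the guarded-import pattern such a patch module enables; the try/except itself is an assumption, not code from this commit:

    # Prefer the upstream pipeline once diffusers ships it; fall back to the
    # local patched copy otherwise. Hypothetical sketch, not part of the diff.
    try:
        from diffusers import PixArtSigmaPipeline
    except ImportError:
        from toolkit.util.pixart_sigma_patch import PixArtSigmaPipeline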
@@ -237,15 +238,22 @@ class StableDiffusion:
         elif self.model_config.is_pixart:
             te_kwargs = {}
             # handle quantization of TE
+            te_is_quantized = False
             if self.model_config.text_encoder_bits == 8:
                 te_kwargs['load_in_8bit'] = True
                 te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
             elif self.model_config.text_encoder_bits == 4:
                 te_kwargs['load_in_4bit'] = True
                 te_kwargs['device_map'] = "auto"
+                te_is_quantized = True
+
+            main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
+            if self.model_config.is_pixart_sigma:
+                main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
             # load the TE in 8bit mode
             text_encoder = T5EncoderModel.from_pretrained(
-                "PixArt-alpha/PixArt-XL-2-1024-MS",
+                main_model_path,
                 subfolder="text_encoder",
                 torch_dtype=self.torch_dtype,
                 **te_kwargs
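For reference, the quantized text-encoder load above is the standard bitsandbytes path in transformers at the time of this commit. A standalone sketch of the 8-bit case, using the Sigma repo id as `main_model_path`; `load_in_8bit`/`load_in_4bit` require bitsandbytes installed, and the dtype mirrors a typical `self.torch_dtype`:

    import torch
    from transformers import T5EncoderModel

    # 8-bit variant of the load above; pass load_in_4bit=True instead for 4-bit.
    text_encoder = T5EncoderModel.from_pretrained(
        "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="text_encoder",
        torch_dtype=torch.float16,
        load_in_8bit=True,    # te_kwargs when text_encoder_bits == 8
        device_map="auto",
    )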
@@ -256,20 +264,46 @@ class StableDiffusion:
             # check if it is just the unet
             if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
                 subfolder = None
-            # load the transformer only from the save
-            transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, subfolder=subfolder)
 
+            if te_is_quantized:
+                # replace the to function with a no-op since it throws an error instead of a warning
+                text_encoder.to = lambda *args, **kwargs: None
 
-            # replace the to function with a no-op since it throws an error instead of a warning
-            text_encoder.to = lambda *args, **kwargs: None
-            pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
-                "PixArt-alpha/PixArt-XL-2-1024-MS",
-                transformer=transformer,
-                text_encoder=text_encoder,
-                dtype=dtype,
-                device=self.device_torch,
-                **load_args
-            ).to(self.device_torch)
+            if self.model_config.is_pixart_sigma:
+                # tmp patches for diffusers PixArtSigmaPipeline Implementation
+                print(
+                    "Changing _init_patched_inputs method of diffusers.models.Transformer2DModel "
+                    "using scripts.diffusers_patches.pixart_sigma_init_patched_inputs")
+                setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
+
+                # load the transformer only from the save
+                transformer = Transformer2DModel.from_pretrained(
+                    model_path,
+                    torch_dtype=self.torch_dtype,
+                    subfolder='transformer'
+                )
+                pipe: PixArtAlphaPipeline = PixArtSigmaPipeline.from_pretrained(
+                    main_model_path,
+                    transformer=transformer,
+                    text_encoder=text_encoder,
+                    dtype=dtype,
+                    device=self.device_torch,
+                    **load_args
+                )
+
+            else:
+
+                # load the transformer only from the save
+                transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype,
+                                                                 subfolder=subfolder)
+                pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
+                    main_model_path,
+                    transformer=transformer,
+                    text_encoder=text_encoder,
+                    dtype=dtype,
+                    device=self.device_torch,
+                    **load_args
+                ).to(self.device_torch)
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
 
             flush()
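Two details in this hunk are worth unpacking. The `text_encoder.to` no-op exists because calling `.to(device)` on a bitsandbytes-quantized model raises an error rather than a warning, so it is stubbed out before any later `.to(self.device_torch)` call can reach it; the Sigma branch likewise constructs its pipeline without a trailing `.to()`. The `setattr` call patches `Transformer2DModel` at the class level before `from_pretrained` runs, so the replacement `_init_patched_inputs` executes while the model is being built. A minimal sketch of that patch-before-load pattern, assuming the repo id shown; the real patch function is the one this commit adds in toolkit.util.pixart_sigma_patch:

    from diffusers.models import Transformer2DModel
    from toolkit.util.pixart_sigma_patch import pixart_sigma_init_patched_inputs

    # Patch the class, not an instance, so the replacement runs inside
    # from_pretrained while the model is being constructed.
    setattr(Transformer2DModel, '_init_patched_inputs', pixart_sigma_init_patched_inputs)
    transformer = Transformer2DModel.from_pretrained(
        "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",  # assumed repo id
        subfolder="transformer",
    )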
@@ -1282,7 +1316,7 @@ class StableDiffusion:
             self.text_encoder,
             prompt,
             truncate=not long_prompts,
-            max_length=max_length,
+            max_length=300 if self.model_config.is_pixart_sigma else 120,
             dropout_prob=dropout_prob
         )
         return PromptEmbeds(
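The hard-coded budget reflects the two model families' training setups: PixArt-Alpha encodes prompts with a 120-token T5 budget, while Sigma was trained with 300 tokens, hence the conditional above. A standalone sketch of how that budget applies at tokenization time; the tokenizer repo id is an assumption, since the pipeline's bundled tokenizer is the real source:

    from transformers import T5Tokenizer

    is_pixart_sigma = True
    tokenizer = T5Tokenizer.from_pretrained(
        "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", subfolder="tokenizer"
    )
    max_length = 300 if is_pixart_sigma else 120  # Sigma's longer prompt budget
    tokens = tokenizer(
        "a cinematic photo of a red fox in the snow",
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )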