Work to improve pixart training

This commit is contained in:
Jaret Burkett
2024-06-23 20:46:48 +00:00
parent 5d47244c57
commit 7165f2d25a
6 changed files with 65 additions and 18 deletions

View File

@@ -38,9 +38,9 @@ import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline
import diffusers
from diffusers import \
AutoencoderKL, \
@@ -185,7 +185,6 @@ class StableDiffusion:
}
if self.model_config.vae_path is not None:
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
if self.custom_pipeline is not None:
pipln = self.custom_pipeline
@@ -331,11 +330,11 @@ class StableDiffusion:
if self.model_config.is_pixart_sigma:
# load the transformer only from the save
transformer = Transformer2DModel.from_pretrained(
model_path,
model_path if self.model_config.unet_path is None else self.model_config.unet_path,
torch_dtype=self.torch_dtype,
subfolder='transformer'
)
pipe: PixArtAlphaPipeline = PixArtSigmaPipeline.from_pretrained(
pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
main_model_path,
transformer=transformer,
text_encoder=text_encoder,
@@ -357,6 +356,9 @@ class StableDiffusion:
device=self.device_torch,
**load_args
).to(self.device_torch)
if self.model_config.unet_sample_size is not None:
pipe.transformer.config.sample_size = self.model_config.unet_sample_size
pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
flush()
@@ -592,7 +594,7 @@ class StableDiffusion:
**extra_args
)
elif self.is_pixart:
pipeline = PixArtAlphaPipeline(
pipeline = PixArtSigmaPipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder,
@@ -1243,7 +1245,7 @@ class StableDiffusion:
elif self.pipeline.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
raise ValueError("Invalid sample size")
raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
orig_height, orig_width = height, width
height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)